@@ -230,9 +230,18 @@ def help_for_type(
230230def class_to_str (class_obj ):
231231 if hasattr (class_obj , "__origin__" ):
232232 # Special handling for Optional types which are represented as Union[X, NoneType]
233- if class_obj . _name == "Optional" :
233+ if getattr ( class_obj , " _name" , None ) == "Optional" :
234234 args = class_to_str (typing .get_args (class_obj )[0 ])
235235 return f"Optional[{ args } ]"
236+ # Special handling for Union types
237+ elif getattr (class_obj , "_name" , None ) == "Union" :
238+ args = typing .get_args (class_obj )
239+ # Filter out NoneType from Union types
240+ args = [arg for arg in args if arg is not type (None )]
241+ if len (args ) == 1 :
242+ return class_to_str (args [0 ])
243+ else :
244+ return " | " .join (class_to_str (arg ) for arg in args )
236245 else :
237246 # Get the base type
238247 base = class_obj .__origin__ .__name__
@@ -260,13 +269,16 @@ def class_to_str(class_obj):
260269 # Handle Callable[[], return_type]
261270 return_type = class_to_str (args [0 ])
262271 return f"{ base } [[], { return_type } ]"
272+ else :
273+ # Handle bare Callable without type arguments
274+ return base
263275 else :
264276 # Handle other generic types
265277 args_str = ", " .join (class_to_str (arg ) for arg in args )
266278 return f"{ base } [{ args_str } ]"
267279 elif class_obj .__module__ == "builtins" :
268280 return class_obj .__name__
269- else :
281+ elif isinstance ( class_obj , type ) :
270282 module = _get_module (class_obj )
271283
272284 full_class_name = f"{ module } .{ class_obj .__name__ } "
@@ -288,6 +300,9 @@ def class_to_str(class_obj):
288300 return "nm.OptimizerModule"
289301
290302 return full_class_name
303+ else :
304+ # Handle non-type objects (like UnionType)
305+ return str (class_obj )
291306
292307
293308def help (
0 commit comments