Skip to content

Commit 4802685

Browse files
committed
Handle more edge cases in --help
Signed-off-by: ShriyaRishab <[email protected]>
1 parent 7474649 commit 4802685

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

nemo_run/help.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,18 @@ def help_for_type(
230230
def class_to_str(class_obj):
231231
if hasattr(class_obj, "__origin__"):
232232
# Special handling for Optional types which are represented as Union[X, NoneType]
233-
if class_obj._name == "Optional":
233+
if getattr(class_obj, "_name", None) == "Optional":
234234
args = class_to_str(typing.get_args(class_obj)[0])
235235
return f"Optional[{args}]"
236+
# Special handling for Union types
237+
elif getattr(class_obj, "_name", None) == "Union":
238+
args = typing.get_args(class_obj)
239+
# Filter out NoneType from Union types
240+
args = [arg for arg in args if arg is not type(None)]
241+
if len(args) == 1:
242+
return class_to_str(args[0])
243+
else:
244+
return " | ".join(class_to_str(arg) for arg in args)
236245
else:
237246
# Get the base type
238247
base = class_obj.__origin__.__name__
@@ -260,13 +269,16 @@ def class_to_str(class_obj):
260269
# Handle Callable[[], return_type]
261270
return_type = class_to_str(args[0])
262271
return f"{base}[[], {return_type}]"
272+
else:
273+
# Handle bare Callable without type arguments
274+
return base
263275
else:
264276
# Handle other generic types
265277
args_str = ", ".join(class_to_str(arg) for arg in args)
266278
return f"{base}[{args_str}]"
267279
elif class_obj.__module__ == "builtins":
268280
return class_obj.__name__
269-
else:
281+
elif isinstance(class_obj, type):
270282
module = _get_module(class_obj)
271283

272284
full_class_name = f"{module}.{class_obj.__name__}"
@@ -288,6 +300,9 @@ def class_to_str(class_obj):
288300
return "nm.OptimizerModule"
289301

290302
return full_class_name
303+
else:
304+
# Handle non-type objects (like UnionType)
305+
return str(class_obj)
291306

292307

293308
def help(

0 commit comments

Comments
 (0)