Skip to content

Commit 5c23986

Browse files
Handle Callable in --help to fix nemo llm export --help error (#217)
* Handle Callable in --help to fix nemo llm export --help error Signed-off-by: ShriyaRishab <[email protected]> * Apply ruff formatting Signed-off-by: ShriyaRishab <[email protected]> --------- Signed-off-by: ShriyaRishab <[email protected]> Co-authored-by: ShriyaRishab <[email protected]>
1 parent 77c8ba3 commit 5c23986

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

nemo_run/help.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,34 @@ def class_to_str(class_obj):
236236
else:
237237
# Get the base type
238238
base = class_obj.__origin__.__name__
239-
# Get the arguments to the type if any (e.g., the 'str' in Optional[str])
240-
args = ", ".join(class_to_str(arg) for arg in typing.get_args(class_obj))
241-
return f"{base}[{args}]"
239+
# Get the arguments to the type if any
240+
args = typing.get_args(class_obj)
241+
242+
# Special handling for Callable types
243+
if base == "Callable":
244+
if len(args) == 2:
245+
if args[0] is ...:
246+
# Handle Callable[..., return_type]
247+
return_type = class_to_str(args[1])
248+
return f"{base}[..., {return_type}]"
249+
elif isinstance(args[0], list):
250+
# Handle Callable[[arg1, arg2], return_type]
251+
arg_types = ", ".join(class_to_str(arg) for arg in args[0])
252+
return_type = class_to_str(args[1])
253+
return f"{base}[[{arg_types}], {return_type}]"
254+
else:
255+
# Handle Callable[Protocol, return_type]
256+
arg_type = class_to_str(args[0])
257+
return_type = class_to_str(args[1])
258+
return f"{base}[{arg_type}, {return_type}]"
259+
elif len(args) == 1:
260+
# Handle Callable[[], return_type]
261+
return_type = class_to_str(args[0])
262+
return f"{base}[[], {return_type}]"
263+
else:
264+
# Handle other generic types
265+
args_str = ", ".join(class_to_str(arg) for arg in args)
266+
return f"{base}[{args_str}]"
242267
elif class_obj.__module__ == "builtins":
243268
return class_obj.__name__
244269
else:

0 commit comments

Comments
 (0)