Skip to content

Commit fa89ad7

Browse files
authored
Adding support for ForwardRef in CLI (#176)
* Adding support for ForwardRef Signed-off-by: Marc Romeyn <[email protected]> * Ruff formatting Signed-off-by: Marc Romeyn <[email protected]> * Adding support for ForwardRef Signed-off-by: Marc Romeyn <[email protected]> * Fixing ruff issues Signed-off-by: Marc Romeyn <[email protected]> * Add extra check for ForwardRef Signed-off-by: Marc Romeyn <[email protected]> * Trying to fix failing test Signed-off-by: Marc Romeyn <[email protected]> * Trying to fix failing test Signed-off-by: Marc Romeyn <[email protected]> * Fix failing test Signed-off-by: Marc Romeyn <[email protected]> * Fix linting issues Signed-off-by: Marc Romeyn <[email protected]> * Adding support for Optional[ForwardRef(..)] Signed-off-by: Marc Romeyn <[email protected]> * Put back normal pyproject.toml Signed-off-by: Marc Romeyn <[email protected]> * Fix bug Signed-off-by: Marc Romeijn <[email protected]> * Trying to fix failing test Signed-off-by: Marc Romeyn <[email protected]> --------- Signed-off-by: Marc Romeyn <[email protected]> Signed-off-by: Marc Romeyn <[email protected]> Signed-off-by: Marc Romeijn <[email protected]>
1 parent 414f007 commit fa89ad7

File tree

6 files changed

+213
-14
lines changed

6 files changed

+213
-14
lines changed

nemo_run/cli/cli_parser.py

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
Union,
3737
get_args,
3838
get_origin,
39+
ForwardRef,
40+
Set,
41+
Tuple,
42+
FrozenSet,
3943
)
4044

4145
import fiddle as fdl
@@ -609,6 +613,7 @@ def __init__(self, strict_mode: bool = True):
609613
Optional: self.parse_optional,
610614
Literal: self.parse_literal,
611615
Path: self.parse_path,
616+
ForwardRef: self.parse_forward_ref,
612617
}
613618
self.custom_parsers = {}
614619
self.strict_mode = strict_mode
@@ -646,6 +651,8 @@ def get_parser(self, annotation: Type) -> Callable[[str, Type], Any]:
646651
Callable[[str, Type], Any]: The parser function for the given type.
647652
"""
648653
origin = get_origin(annotation) or annotation
654+
if str(origin).startswith("ForwardRef"):
655+
return self.parse_forward_ref
649656
return self.custom_parsers.get(origin) or self.parsers.get(origin) or self.parse_unknown
650657

651658
def parse(self, value: str, annotation: Type) -> Any:
@@ -950,6 +957,9 @@ def parse_path(self, value: str, _: Type) -> Path:
950957
raise ParseError(value, Path, "Invalid path: contains null character")
951958
return Path(value.strip("'\" "))
952959

960+
def parse_forward_ref(self, value: str, annotation) -> Any:
961+
return value
962+
953963
def infer_type(self, value: str) -> Type:
954964
"""Infer the type of a string value.
955965
@@ -980,7 +990,9 @@ def parse_value(value: str, annotation: Type = None) -> Any:
980990

981991
@cli_exception_handler
982992
def parse_cli_args(
983-
fn: Callable, args: List[str], output_type: Type[TypeVar("OutputT", Partial, Config)] = Partial
993+
fn: Callable,
994+
args: List[str],
995+
output_type: Type[TypeVar("OutputT", Partial, Config)] = Partial,
984996
) -> TypeVar("OutputT", Partial, Config):
985997
"""Parse command-line arguments and apply them to a function or class.
986998
@@ -1109,6 +1121,10 @@ def dummy_model_config():
11091121
annotation = param.annotation
11101122
logger.debug(f"Parsing value {value} as {annotation}")
11111123

1124+
annotation = _maybe_resolve_annotation(
1125+
getattr(nested, "__fn_or_cls__", nested), arg_name, annotation
1126+
)
1127+
11121128
if annotation:
11131129
try:
11141130
parsed_value = parse_factory(fn, arg_name, annotation, value)
@@ -1132,7 +1148,9 @@ def dummy_model_config():
11321148
else:
11331149
if not hasattr(nested, arg_name):
11341150
raise UndefinedVariableError(
1135-
f"Cannot use '{op.value}' on undefined variable", arg, {"key": key}
1151+
f"Cannot use '{op.value}' on undefined variable",
1152+
arg,
1153+
{"key": key},
11361154
)
11371155
setattr(
11381156
nested,
@@ -1277,7 +1295,9 @@ def _args_to_kwargs(fn: Callable, args: List[str]) -> List[str]:
12771295
for arg in args:
12781296
if "=" not in arg:
12791297
raise ArgumentParsingError(
1280-
"Positional argument found after keyword argument", arg, {"position": len(args)}
1298+
"Positional argument found after keyword argument",
1299+
arg,
1300+
{"position": len(args)},
12811301
)
12821302

12831303
return args
@@ -1304,7 +1324,9 @@ def _args_to_kwargs(fn: Callable, args: List[str]) -> List[str]:
13041324
positional_count += 1
13051325
else:
13061326
raise ArgumentParsingError(
1307-
"Too many positional arguments", arg, {"max_positional": len(params)}
1327+
"Too many positional arguments",
1328+
arg,
1329+
{"max_positional": len(params)},
13081330
)
13091331

13101332
return updated_args
@@ -1334,3 +1356,92 @@ def parse_attribute(attr, nested):
13341356
) from e
13351357

13361358
return result
1359+
1360+
1361+
def _maybe_resolve_annotation(fn: Callable, arg_name: str, annotation: Any) -> Any:
1362+
"""Internal function to resolve an annotation to its actual type.
1363+
1364+
This function handles string annotations, ForwardRef, and generic types (e.g., Optional, List)
1365+
by resolving string annotations within them, using TYPE_CHECKING blocks and their imports.
1366+
1367+
Args:
1368+
fn (Callable): The function containing the annotation
1369+
arg_name (str): The name of the parameter with the annotation
1370+
annotation (Any): The annotation to resolve (string, ForwardRef, or type)
1371+
1372+
Returns:
1373+
Any: The resolved type, or the original annotation if resolution fails
1374+
"""
1375+
# Case 1: Annotation is a string
1376+
if isinstance(annotation, str):
1377+
resolved = _resolve_type_checking_annotation(fn, annotation)
1378+
return resolved if resolved != annotation else annotation
1379+
1380+
# Case 2: Annotation is a ForwardRef
1381+
elif isinstance(annotation, ForwardRef):
1382+
return _resolve_type_checking_annotation(fn, annotation.__forward_arg__)
1383+
1384+
# Case 3: Annotation is a generic type (e.g., Optional, List, Union)
1385+
elif (origin := get_origin(annotation)) is not None:
1386+
args = get_args(annotation)
1387+
resolved_args = tuple(_maybe_resolve_annotation(fn, arg_name, arg) for arg in args)
1388+
if origin is list:
1389+
return List[resolved_args[0]]
1390+
elif origin is dict:
1391+
return Dict[resolved_args[0], resolved_args[1]]
1392+
elif origin is tuple:
1393+
return Tuple[resolved_args]
1394+
elif origin is set:
1395+
return Set[resolved_args[0]]
1396+
elif origin is frozenset:
1397+
return FrozenSet[resolved_args[0]]
1398+
elif origin is Union:
1399+
return Union[resolved_args]
1400+
else:
1401+
return annotation # Unhandled generic types return as-is
1402+
1403+
# Case 4: Annotation is a non-generic type (e.g., int, str)
1404+
else:
1405+
return annotation
1406+
1407+
1408+
def _resolve_type_checking_annotation(fn: Callable, annotation: str) -> Any:
1409+
"""Helper function to resolve a string annotation to its actual type using TYPE_CHECKING imports."""
1410+
if hasattr(fn, "__fn_or_cls__"):
1411+
fn = fn.__fn_or_cls__
1412+
1413+
try:
1414+
source_file = inspect.getsourcefile(fn)
1415+
if not source_file:
1416+
return annotation
1417+
with open(source_file, "r") as f:
1418+
source = f.read()
1419+
tree = ast.parse(source)
1420+
type_checking_imports = {}
1421+
for node in ast.walk(tree):
1422+
if (
1423+
isinstance(node, ast.If)
1424+
and isinstance(node.test, ast.Name)
1425+
and node.test.id == "TYPE_CHECKING"
1426+
):
1427+
for stmt in node.body:
1428+
if isinstance(stmt, (ast.Import, ast.ImportFrom)):
1429+
if isinstance(stmt, ast.Import):
1430+
for name in stmt.names:
1431+
type_checking_imports[name.asname or name.name] = name.name
1432+
else: # ImportFrom
1433+
module = stmt.module or ""
1434+
for name in stmt.names:
1435+
full_name = f"{module}.{name.name}" if module else name.name
1436+
type_checking_imports[name.asname or name.name] = full_name
1437+
if annotation in type_checking_imports:
1438+
try:
1439+
full_path = type_checking_imports[annotation]
1440+
module_name, type_name = full_path.rsplit(".", 1)
1441+
module = importlib.import_module(module_name)
1442+
return getattr(module, type_name)
1443+
except (ImportError, AttributeError):
1444+
pass
1445+
except Exception:
1446+
pass
1447+
return annotation

nemo_run/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ def get_type_namespace(typ: Type | Callable) -> str:
9797
if isinstance(typ, fdl.Buildable):
9898
typ = typ.__fn_or_cls__
9999

100-
return f"{module}.{typ.__qualname__}"
100+
_name = getattr(typ, "__qualname__", str(typ))
101+
if _name.startswith("ForwardRef"):
102+
_name = _name.split(".")[-1]
103+
return f"{module}.{_name}"
101104

102105

103106
def get_underlying_types(type_hint: typing.Any) -> typing.Set[typing.Type]:

test/cli/test_api.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import sys
1818
from configparser import ConfigParser
1919
from dataclasses import dataclass, field
20-
from typing import Annotated, List, Optional, Union
20+
from typing import Annotated, List, Optional, Union, TYPE_CHECKING
2121
from unittest.mock import Mock, patch
2222

2323
import fiddle as fdl
@@ -46,6 +46,10 @@
4646
from test.dummy_factory import DummyModel, dummy_entrypoint
4747
import nemo_run.cli.cli_parser # Import the module to mock its function
4848

49+
if TYPE_CHECKING:
50+
from test.dummy_type import RealType
51+
52+
4953
_RUN_FACTORIES_ENTRYPOINT: str = """
5054
[nemo_run.cli]
5155
dummy = test.dummy_factory
@@ -519,6 +523,42 @@ def test_factory_for_entrypoint(self):
519523
cfg = run.cli.resolve_factory(dummy_entrypoint, "dummy_recipe")()
520524
assert cfg.dummy.hidden == 2000
521525

526+
def test_forward_ref_with_real_type_factory(self):
527+
"""Test that ForwardRef works when factory is registered for the actual type."""
528+
529+
# Function that uses ForwardRef to the module-level RealType class
530+
def func(param: Optional["RealType"] = None):
531+
pass
532+
533+
from test.dummy_type import RealType as _RealType
534+
535+
# Register the factory in the module's global namespace
536+
# The factory returns a RealType instance with a specific value
537+
@run.cli.factory
538+
@run.autoconvert
539+
def real_type_factory() -> _RealType:
540+
return _RealType(value=100)
541+
542+
@run.cli.factory(target=func, target_arg="param")
543+
@run.autoconvert
544+
def other_factory() -> _RealType:
545+
return _RealType(value=200)
546+
547+
try:
548+
# Now test parsing works using the factory name
549+
result = cli_api.parse_cli_args(func, ["param=real_type_factory"])
550+
assert isinstance(result.param, run.Config)
551+
assert result.param.value == 100
552+
553+
result = cli_api.parse_cli_args(func, ["param=other_factory"])
554+
assert isinstance(result.param, run.Config)
555+
assert result.param.value == 200
556+
557+
finally:
558+
# Clean up - remove the factory from registry
559+
if hasattr(sys.modules[__name__], "real_type_factory"):
560+
delattr(sys.modules[__name__], "real_type_factory")
561+
522562

523563
class TestListEntrypoints:
524564
@dataclass

test/cli/test_cli_parser.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515

1616
import sys
1717
from pathlib import Path
18-
from typing import Any, Dict, List, Literal, Optional, Type, Union
18+
from typing import (
19+
Any,
20+
Dict,
21+
List,
22+
Literal,
23+
Optional,
24+
Type,
25+
Union,
26+
ForwardRef,
27+
)
1928

2029
import pytest
2130

@@ -156,6 +165,22 @@ def func(a: Literal["red", "green", "blue"]):
156165
assert "Error parsing argument" in str(exc_info.value)
157166
assert "Expected one of ('red', 'green', 'blue'), got 'yellow'" in str(exc_info.value)
158167

168+
def test_forward_ref_parsing(self):
169+
def func(tokenizer: Optional[ForwardRef("TokenizerSpec")]):
170+
pass
171+
172+
# Test with string value
173+
result = parse_cli_args(func, ["tokenizer=tokenizer_spec"])
174+
assert result.tokenizer.hidden == 1000
175+
176+
# Test with None
177+
result = parse_cli_args(func, ["tokenizer=None"])
178+
assert result.tokenizer is None
179+
180+
# Test with null (alternative None syntax)
181+
result = parse_cli_args(func, ["tokenizer=null"])
182+
assert result.tokenizer is None
183+
159184

160185
class TestFactoryFunctionParsing:
161186
def test_simple_factory_function(self):
@@ -351,7 +376,8 @@ def func(a: int):
351376
pass
352377

353378
with pytest.raises(
354-
ArgumentValueError, match="Invalid argument: No parameter named 'b' exists for"
379+
ArgumentValueError,
380+
match="Invalid argument: No parameter named 'b' exists for",
355381
):
356382
parse_cli_args(func, ["b=5"])
357383

@@ -401,7 +427,8 @@ def test_parse_int(self):
401427
assert parse_value("0", int) == 0
402428
assert parse_value("+789", int) == 789
403429
with pytest.raises(
404-
ParseError, match="Failed to parse '3.14' as <class 'int'>: Invalid integer literal"
430+
ParseError,
431+
match="Failed to parse '3.14' as <class 'int'>: Invalid integer literal",
405432
):
406433
parse_value("3.14", int)
407434
with pytest.raises(
@@ -450,7 +477,8 @@ def test_parse_bool(self):
450477
):
451478
parse_value("not_a_bool", bool)
452479
with pytest.raises(
453-
ParseError, match="Failed to parse '2' as <class 'bool'>: Cannot convert .* to bool"
480+
ParseError,
481+
match="Failed to parse '2' as <class 'bool'>: Cannot convert .* to bool",
454482
):
455483
parse_value("2", bool)
456484

@@ -468,7 +496,10 @@ def test_parse_list(self):
468496

469497
def test_parse_dict(self):
470498
assert parse_value('{"a": 1, "b": 2}', Dict[str, int]) == {"a": 1, "b": 2}
471-
assert parse_value('{"x": "foo", "y": "bar"}', Dict[str, str]) == {"x": "foo", "y": "bar"}
499+
assert parse_value('{"x": "foo", "y": "bar"}', Dict[str, str]) == {
500+
"x": "foo",
501+
"y": "bar",
502+
}
472503
assert parse_value("{}", Dict[str, Any]) == {}
473504
with pytest.raises(ParseError, match="Failed to parse 'not_a_dict' as typing.Dict"):
474505
parse_value("not_a_dict", Dict[str, int])
@@ -562,7 +593,8 @@ class CustomType:
562593
pass
563594

564595
with pytest.raises(
565-
ParseError, match="Failed to parse 'value' as <class '.*CustomType'>: Unsupported type"
596+
ParseError,
597+
match="Failed to parse 'value' as <class '.*CustomType'>: Unsupported type",
566598
):
567599
strict_parser.parse("value", CustomType)
568600

@@ -602,7 +634,11 @@ def test_parse_constructor(self, parser):
602634

603635
def test_parse_comprehension(self, parser):
604636
assert parser.parse_comprehension("[x for x in range(3)]") == [0, 1, 2]
605-
assert parser.parse_comprehension("{x: x**2 for x in range(3)}") == {0: 0, 1: 1, 2: 4}
637+
assert parser.parse_comprehension("{x: x**2 for x in range(3)}") == {
638+
0: 0,
639+
1: 1,
640+
2: 4,
641+
}
606642

607643
def test_parse_lambda(self, parser):
608644
# Test safe lambdas

test/dummy_factory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17-
from typing import List
17+
from typing import List, ForwardRef
1818

1919
import nemo_run as run
2020

@@ -99,6 +99,12 @@ def plugin_list(arg: int = 20) -> List[run.Plugin]:
9999
]
100100

101101

102+
@run.cli.factory
103+
@run.autoconvert
104+
def tokenizer_spec() -> ForwardRef("TokenizerSpec"):
105+
return DummyModel(hidden=1000)
106+
107+
102108
def dummy_train(dummy_model: DummyModel, dummy_trainer: DummyTrainer): ...
103109

104110

test/dummy_type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class RealType:
2+
def __init__(self, value=42):
3+
self.value = value

0 commit comments

Comments
 (0)