Skip to content

Commit 0dfea6e

Browse files
committed
Adding support for ForwardRef
1 parent f3c8e99 commit 0dfea6e

File tree

4 files changed

+55
-3
lines changed

4 files changed

+55
-3
lines changed

nemo_run/cli/cli_parser.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ def __init__(self, strict_mode: bool = True):
609609
Optional: self.parse_optional,
610610
Literal: self.parse_literal,
611611
Path: self.parse_path,
612+
'ForwardRef': self.parse_forward_ref, # Add ForwardRef handling
612613
}
613614
self.custom_parsers = {}
614615
self.strict_mode = strict_mode
@@ -646,6 +647,8 @@ def get_parser(self, annotation: Type) -> Callable[[str, Type], Any]:
646647
Callable[[str, Type], Any]: The parser function for the given type.
647648
"""
648649
origin = get_origin(annotation) or annotation
650+
if str(origin).startswith('ForwardRef'):
651+
return self.parse_forward_ref
649652
return self.custom_parsers.get(origin) or self.parsers.get(origin) or self.parse_unknown
650653

651654
def parse(self, value: str, annotation: Type) -> Any:
@@ -950,6 +953,23 @@ def parse_path(self, value: str, _: Type) -> Path:
950953
raise ParseError(value, Path, "Invalid path: contains null character")
951954
return Path(value.strip("'\" "))
952955

956+
def parse_forward_ref(self, value: str, annotation) -> Any:
957+
"""Parse a string value as a ForwardRef type.
958+
959+
Args:
960+
value (str): The string value to parse.
961+
annotation: The ForwardRef type annotation.
962+
963+
Returns:
964+
Any: The parsed value.
965+
966+
Raises:
967+
ParseError: If the value cannot be parsed.
968+
"""
969+
# For ForwardRef types, we'll just return the value as is
970+
# since the actual type resolution happens later
971+
return value
972+
953973
def infer_type(self, value: str) -> Type:
954974
"""Infer the type of a string value.
955975

nemo_run/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ def get_type_namespace(typ: Type | Callable) -> str:
9797
if isinstance(typ, fdl.Buildable):
9898
typ = typ.__fn_or_cls__
9999

100-
return f"{module}.{typ.__qualname__}"
100+
_name = getattr(typ, "__qualname__", str(typ))
101+
if _name.startswith("ForwardRef"):
102+
_name = _name.split(".")[-1]
103+
return f"{module}.{_name}"
101104

102105

103106
def get_underlying_types(type_hint: typing.Any) -> typing.Set[typing.Type]:

test/cli/test_cli_parser.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sys
1717
from pathlib import Path
1818
from test.dummy_factory import DummyModel
19-
from typing import Any, Dict, List, Literal, Optional, Type, Union
19+
from typing import Any, Dict, List, Literal, Optional, Type, Union, ForwardRef, TYPE_CHECKING
2020

2121
import pytest
2222

@@ -37,6 +37,9 @@
3737
)
3838
from nemo_run.config import Config, Partial
3939

40+
if TYPE_CHECKING:
41+
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
42+
4043

4144
class TestSimpleValueParsing:
4245
def test_int_parsing(self):
@@ -153,6 +156,23 @@ def func(a: Literal["red", "green", "blue"]):
153156
assert "Error parsing argument" in str(exc_info.value)
154157
assert "Expected one of ('red', 'green', 'blue'), got 'yellow'" in str(exc_info.value)
155158

159+
def test_forward_ref_parsing(self):
160+
# Use string annotation to avoid needing the actual class at runtime
161+
def func(tokenizer: Optional[ForwardRef("TokenizerSpec")]):
162+
pass
163+
164+
# Test with string value
165+
result = parse_cli_args(func, ["tokenizer=tokenizer_spec"])
166+
assert result.tokenizer.hidden == 1000
167+
168+
# Test with None
169+
result = parse_cli_args(func, ["tokenizer=None"])
170+
assert result.tokenizer is None
171+
172+
# Test with null (alternative None syntax)
173+
result = parse_cli_args(func, ["tokenizer=null"])
174+
assert result.tokenizer is None
175+
156176

157177
class TestFactoryFunctionParsing:
158178
def test_simple_factory_function(self):

test/dummy_factory.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17-
from typing import List
17+
from typing import List, TYPE_CHECKING, ForwardRef
1818

1919
import nemo_run as run
2020

21+
if TYPE_CHECKING:
22+
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
23+
2124

2225
@dataclass
2326
class DummyModel:
@@ -99,6 +102,12 @@ def plugin_list(arg: int = 20) -> List[run.Plugin]:
99102
]
100103

101104

105+
@run.cli.factory
106+
@run.autoconvert
107+
def tokenizer_spec() -> ForwardRef("TokenizerSpec"):
108+
return DummyModel(hidden=1000)
109+
110+
102111
def dummy_train(dummy_model: DummyModel, dummy_trainer: DummyTrainer): ...
103112

104113

0 commit comments

Comments
 (0)