Skip to content

Commit 67de154

Browse files
authored
Roll back get_underlying_types change + introduce extract_constituent… (#223)
* Roll back get_underlying_types change + introduce extract_constituent_types Signed-off-by: Marc Romeyn <[email protected]> * Bring back test Signed-off-by: Marc Romeyn <[email protected]> --------- Signed-off-by: Marc Romeyn <[email protected]>
1 parent 2b59fd5 commit 67de154

File tree

6 files changed

+198
-133
lines changed

6 files changed

+198
-133
lines changed

nemo_run/cli/api.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@
2020
import sys
2121
from dataclasses import dataclass, field
2222
from functools import cache, wraps
23+
import typing
2324
from typing import (
2425
Any,
26+
Annotated,
2527
Callable,
2628
Dict,
2729
Generic,
2830
List,
2931
Literal,
3032
Optional,
3133
Protocol,
34+
Set,
3235
Tuple,
3336
Type,
3437
TypeVar,
38+
Union,
3539
get_args,
3640
get_type_hints,
3741
overload,
@@ -62,7 +66,7 @@
6266
Partial,
6367
get_nemorun_home,
6468
get_type_namespace,
65-
get_underlying_types,
69+
RECURSIVE_TYPES,
6670
)
6771
from nemo_run.core.execution import LocalExecutor, SkypilotExecutor, SlurmExecutor
6872
from nemo_run.core.execution.base import Executor
@@ -475,7 +479,7 @@ def resolve_factory(
475479
if isinstance(target, str):
476480
fn = catalogue._get((target, name))
477481
else:
478-
types = get_underlying_types(target)
482+
types = extract_constituent_types(target)
479483
num_missing = 0
480484
for t in types:
481485
_namespace = get_type_namespace(t)
@@ -1827,6 +1831,132 @@ def _export_config(output_path: str, format: Optional[str] = None) -> None:
18271831
_export_config(to_json, format="json")
18281832

18291833

1834+
def extract_constituent_types(type_hint: Any) -> Set[Type]:
1835+
"""
1836+
Extract all constituent types from a type hint, including generics and their type arguments.
1837+
This function recursively traverses complex type hints to find all underlying types.
1838+
1839+
For example:
1840+
- For Union[int, str] -> {int, str}
1841+
- For List[int] -> {list, int}
1842+
- For Dict[str, Optional[int]] -> {dict, str, int}
1843+
- For Callable[..., int] -> {Callable}
1844+
- For TypeVar('T') -> {TypeVar}
1845+
- For Annotated[List[int], "metadata"] -> {list, int}
1846+
- For Optional[List[int]] -> {list, int}
1847+
- For a class MyClass -> {MyClass}
1848+
1849+
The function handles:
1850+
- Basic types (int, str, etc.)
1851+
- Generic types (List, Dict, etc.)
1852+
- Union and Optional types
1853+
- Annotated types (extracts the underlying type)
1854+
- TypeVars and ForwardRefs
1855+
- Callable types
1856+
- Custom classes
1857+
1858+
Args:
1859+
type_hint: A type hint to analyze. Can be any valid Python type annotation,
1860+
including complex nested types.
1861+
1862+
Returns:
1863+
A set of all constituent types found in the type hint. NoneType is excluded
1864+
from the results.
1865+
1866+
Note:
1867+
This function is particularly useful for type checking and reflection,
1868+
where you need to know all the possible types that could be involved
1869+
in a type annotation.
1870+
"""
1871+
# Special case for functions and classes - return the type itself
1872+
if inspect.isfunction(type_hint) or inspect.isclass(type_hint):
1873+
return {type_hint}
1874+
1875+
# Handle older style type hints (_GenericAlias)
1876+
if hasattr(typing, "_GenericAlias") and isinstance(type_hint, typing._GenericAlias): # type: ignore
1877+
# Correctly handle Annotated by getting the first argument (the actual type)
1878+
if str(type_hint).startswith("typing.Annotated") or str(type_hint).startswith(
1879+
"typing_extensions.Annotated"
1880+
):
1881+
# Recurse on the actual type, skipping metadata
1882+
return extract_constituent_types(type_hint.__args__[0])
1883+
else:
1884+
origin = type_hint.__origin__
1885+
1886+
if origin in RECURSIVE_TYPES:
1887+
types = set()
1888+
for arg in type_hint.__args__:
1889+
# Add check to skip NoneType here as well
1890+
if arg is not type(None):
1891+
types.update(extract_constituent_types(arg))
1892+
return types
1893+
# If not a recursive type handled above, treat it like a concrete generic
1894+
# Collect types from arguments
1895+
result = set()
1896+
for arg in type_hint.__args__:
1897+
if arg is not type(
1898+
None
1899+
): # Also skip NoneType here for generics like list[Optional[int]]
1900+
result.update(extract_constituent_types(arg))
1901+
# Add the origin itself (e.g., list, dict)
1902+
if isinstance(origin, type):
1903+
result.add(origin)
1904+
# Add the original type_hint if it's a specific generic instantiation (and not a Union/Optional)
1905+
if origin is not None and origin not in RECURSIVE_TYPES:
1906+
result.add(type_hint) # type_hint is the _GenericAlias itself
1907+
return result # Return collected types
1908+
1909+
# Handle Python 3.9+ style type hints
1910+
origin = typing.get_origin(type_hint)
1911+
args = typing.get_args(type_hint)
1912+
1913+
# Base case: no origin or args means it's a simple type
1914+
if origin is None:
1915+
if type_hint is type(None):
1916+
return set()
1917+
if isinstance(type_hint, type):
1918+
return {type_hint}
1919+
return {type_hint} # Return the hint itself if not a type (e.g., TypeVar)
1920+
1921+
# Handle Annotated for Python 3.9+
1922+
if origin is Annotated:
1923+
# Recurse on the actual type argument, skipping metadata
1924+
return extract_constituent_types(args[0])
1925+
1926+
# Union type (including Optional)
1927+
if origin is Union:
1928+
result = set()
1929+
for arg in args:
1930+
if arg is not type(None): # Skip NoneType in Unions
1931+
result.update(extract_constituent_types(arg))
1932+
return result
1933+
1934+
# List, Dict, etc. - collect types from arguments
1935+
result = set()
1936+
for arg in args:
1937+
result.update(extract_constituent_types(arg))
1938+
1939+
# Include the origin type itself if it's a class
1940+
# This handles both typing module types and Python 3.9+ built-in generic types
1941+
if isinstance(origin, type):
1942+
result.add(origin)
1943+
1944+
# Add the original type_hint if it's a specific generic instantiation (and not a Union/Annotated)
1945+
if origin not in (None, Union, Annotated):
1946+
# type_hint is the original parameterized generic, e.g., List[int]
1947+
# Add it only if it's indeed a generic (origin of type_hint itself is not None)
1948+
if typing.get_origin(type_hint) is not None:
1949+
result.add(type_hint)
1950+
1951+
# If no types were added, return the original type hint to preserve behavior
1952+
if (
1953+
not result
1954+
): # This covers cases like type_hint being a TypeVar that resulted in an empty set initially
1955+
return {type_hint}
1956+
1957+
return result
1958+
1959+
18301960
if __name__ == "__main__":
18311961
app = create_cli()
18321962
app()

nemo_run/cli/cli_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,8 @@ def parse_factory(parent: Type, arg_name: str, arg_type: Type, value: str) -> An
12551255
"""
12561256
import catalogue
12571257

1258-
from nemo_run.config import Partial, get_type_namespace, get_underlying_types
1258+
from nemo_run.config import Partial, get_type_namespace
1259+
from nemo_run.cli.api import extract_constituent_types
12591260

12601261
def _get_from_registry(val, annotation, name):
12611262
if catalogue.check_exists(get_type_namespace(annotation), val):
@@ -1297,7 +1298,7 @@ def parse_single_factory(factory_str):
12971298
try:
12981299
factory_fn = _get_from_registry(factory_name, parent, name=arg_name)
12991300
except catalogue.RegistryError:
1300-
types = get_underlying_types(arg_type, include_self=True)
1301+
types = extract_constituent_types(arg_type)
13011302
for t in types:
13021303
try:
13031304
factory_fn = _get_from_registry(factory_name, t, name=factory_name)

nemo_run/config.py

Lines changed: 9 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import typing
2525
from pathlib import Path
2626
from types import MappingProxyType
27-
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, Union, Set, get_args
27+
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, Union, get_args
2828

2929
import fiddle as fdl
3030
import fiddle._src.experimental.dataclasses as fdl_dc
@@ -103,109 +103,20 @@ def get_type_namespace(typ: Type | Callable) -> str:
103103
return f"{module}.{_name}"
104104

105105

106-
def get_underlying_types(type_hint: Any, include_self: bool = False) -> Set[Type]:
107-
"""
108-
Retrieve the underlying types from a type hint, handling generic types.
109-
110-
Args:
111-
type_hint: The type hint to analyze
112-
include_self: If True, include the type_hint itself in the result if it's a specific generic.
113-
114-
Returns:
115-
A set of all underlying types
116-
"""
117-
# Special case for functions and classes - return the type itself
118-
if inspect.isfunction(type_hint) or inspect.isclass(type_hint):
119-
return {type_hint}
120-
121-
# Handle older style type hints (_GenericAlias)
122-
if hasattr(typing, "_GenericAlias") and isinstance(type_hint, typing._GenericAlias): # type: ignore
123-
# Correctly handle Annotated by getting the first argument (the actual type)
124-
if str(type_hint).startswith("typing.Annotated") or str(type_hint).startswith(
125-
"typing_extensions.Annotated"
126-
):
127-
# Recurse on the actual type, skipping metadata
128-
return get_underlying_types(type_hint.__args__[0], include_self=include_self)
106+
def get_underlying_types(type_hint: typing.Any) -> typing.Set[typing.Type]:
107+
if isinstance(type_hint, typing._GenericAlias): # type: ignore
108+
if str(type_hint).startswith("typing.Annotated"):
109+
origin = type_hint.__origin__
110+
if hasattr(origin, "__origin__"):
111+
origin = origin.__origin__
129112
else:
130113
origin = type_hint.__origin__
131-
132114
if origin in RECURSIVE_TYPES:
133115
types = set()
134116
for arg in type_hint.__args__:
135-
# Add check to skip NoneType here as well
136-
if arg is not type(None):
137-
types.update(get_underlying_types(arg, include_self=include_self))
117+
types.update(get_underlying_types(arg))
138118
return types
139-
# If not a recursive type handled above, treat it like a concrete generic
140-
# Collect types from arguments
141-
result = set()
142-
for arg in type_hint.__args__:
143-
if arg is not type(
144-
None
145-
): # Also skip NoneType here for generics like list[Optional[int]]
146-
result.update(get_underlying_types(arg, include_self=include_self))
147-
# Add the origin itself (e.g., list, dict)
148-
if isinstance(origin, type):
149-
result.add(origin)
150-
# Add the original type_hint if it's a specific generic instantiation (and not a Union/Optional)
151-
if include_self and origin is not None and origin not in RECURSIVE_TYPES:
152-
result.add(type_hint) # type_hint is the _GenericAlias itself
153-
return result # Return collected types
154-
155-
# Handle Python 3.9+ style type hints
156-
origin = typing.get_origin(type_hint)
157-
args = typing.get_args(type_hint)
158-
159-
# Base case: no origin or args means it's a simple type
160-
if origin is None:
161-
if type_hint is type(None):
162-
return set()
163-
if isinstance(type_hint, type):
164-
return {type_hint}
165-
return {type_hint} # Return the hint itself if not a type (e.g., TypeVar)
166-
167-
# Handle Annotated for Python 3.9+
168-
if origin is Annotated:
169-
# Recurse on the actual type argument, skipping metadata
170-
return get_underlying_types(args[0], include_self=include_self)
171-
172-
# Union type (including Optional)
173-
if origin is typing.Union:
174-
result = set()
175-
for arg in args:
176-
if arg is not type(None): # Skip NoneType in Unions
177-
result.update(get_underlying_types(arg, include_self=include_self))
178-
return result
179-
180-
# List, Dict, etc. - collect types from arguments
181-
result = set()
182-
for arg in args:
183-
result.update(get_underlying_types(arg, include_self=include_self))
184-
185-
# Include the origin type itself if it's a class
186-
# This handles both typing module types and Python 3.9+ built-in generic types
187-
if isinstance(origin, type):
188-
result.add(origin)
189-
190-
# Add the original type_hint if it's a specific generic instantiation (and not a Union/Annotated)
191-
if (
192-
include_self
193-
and origin is not None
194-
and origin is not typing.Union
195-
and origin is not Annotated
196-
):
197-
# type_hint is the original parameterized generic, e.g., List[int]
198-
# Add it only if it's indeed a generic (origin of type_hint itself is not None)
199-
if typing.get_origin(type_hint) is not None:
200-
result.add(type_hint)
201-
202-
# If no types were added, return the original type hint to preserve behavior
203-
if (
204-
not result
205-
): # This covers cases like type_hint being a TypeVar that resulted in an empty set initially
206-
return {type_hint}
207-
208-
return result
119+
return {type_hint}
209120

210121

211122
def from_dict(raw_data: dict | list | str | float | int | bool, cls: Type[_T]) -> _T:

test/cli/test_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
_load_workspace_file,
4343
_load_workspace,
4444
main as cli_main,
45+
extract_constituent_types,
4546
)
4647
from test.dummy_factory import DummyModel, dummy_entrypoint
4748
import nemo_run.cli.cli_parser # Import the module to mock its function
@@ -1861,3 +1862,32 @@ def test_format_help_with_docs_flag(
18611862
with patch("sys.argv", ["script.py", "test_cmd", "--help", "-d"]):
18621863
cmd.format_help(mock_ctx, mock_formatter)
18631864
entrypoint.help.assert_called_once_with(mock_console, with_docs=True)
1865+
1866+
1867+
class TestExtractConstituentTypes:
1868+
@pytest.mark.parametrize(
1869+
"type_hint, expected_types",
1870+
[
1871+
(int, {int}),
1872+
(str, {str}),
1873+
(bool, {bool}),
1874+
(float, {float}),
1875+
(list[int], {list, int, list[int]}),
1876+
(dict[str, float], {dict, str, float, dict[str, float]}),
1877+
(Union[int, str], {int, str}),
1878+
(Optional[int], {int}), # Optional[T] is Union[T, NoneType]
1879+
(list[Union[int, str]], {list, int, str, list[Union[int, str]]}),
1880+
(dict[str, list[int]], {dict, str, list, int, dict[str, list[int]], list[int]}),
1881+
(Optional[list[str]], {list, str, list[str]}),
1882+
(Annotated[int, "meta"], {int}),
1883+
(Annotated[list[str], "meta"], {list, str, list[str]}),
1884+
(Annotated[Optional[dict[str, bool]], "meta"], {dict, str, bool, dict[str, bool]}),
1885+
(Union[Annotated[int, "int_meta"], Annotated[str, "str_meta"]], {int, str}),
1886+
(DummyModel, {DummyModel}),
1887+
(Optional[DummyModel], {DummyModel}),
1888+
(list[DummyModel], {list, DummyModel, list[DummyModel]}),
1889+
],
1890+
)
1891+
def test_various_type_hints(self, type_hint, expected_types):
1892+
"""Test get_underlying_types with various type hints."""
1893+
assert extract_constituent_types(type_hint) == expected_types

test/run/torchx_backend/schedulers/test_slurm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ def test_tunnel_log_iterator():
297297

298298

299299
@mock.patch("nemo_run.run.torchx_backend.schedulers.slurm.SLURM_JOB_DIRS", "mock_job_dirs_path")
300-
@pytest.mark.xfail
301300
def test_get_job_dirs():
302301
# Single test using direct file manipulation instead of complex mocks
303302
with tempfile.TemporaryDirectory() as temp_dir:

0 commit comments

Comments
 (0)