|
20 | 20 | import sys |
21 | 21 | from dataclasses import dataclass, field |
22 | 22 | from functools import cache, wraps |
| 23 | +import typing |
23 | 24 | from typing import ( |
24 | 25 | Any, |
| 26 | + Annotated, |
25 | 27 | Callable, |
26 | 28 | Dict, |
27 | 29 | Generic, |
28 | 30 | List, |
29 | 31 | Literal, |
30 | 32 | Optional, |
31 | 33 | Protocol, |
| 34 | + Set, |
32 | 35 | Tuple, |
33 | 36 | Type, |
34 | 37 | TypeVar, |
| 38 | + Union, |
35 | 39 | get_args, |
36 | 40 | get_type_hints, |
37 | 41 | overload, |
|
62 | 66 | Partial, |
63 | 67 | get_nemorun_home, |
64 | 68 | get_type_namespace, |
65 | | - get_underlying_types, |
| 69 | + RECURSIVE_TYPES, |
66 | 70 | ) |
67 | 71 | from nemo_run.core.execution import LocalExecutor, SkypilotExecutor, SlurmExecutor |
68 | 72 | from nemo_run.core.execution.base import Executor |
@@ -475,7 +479,7 @@ def resolve_factory( |
475 | 479 | if isinstance(target, str): |
476 | 480 | fn = catalogue._get((target, name)) |
477 | 481 | else: |
478 | | - types = get_underlying_types(target) |
| 482 | + types = extract_constituent_types(target) |
479 | 483 | num_missing = 0 |
480 | 484 | for t in types: |
481 | 485 | _namespace = get_type_namespace(t) |
@@ -1827,6 +1831,132 @@ def _export_config(output_path: str, format: Optional[str] = None) -> None: |
1827 | 1831 | _export_config(to_json, format="json") |
1828 | 1832 |
|
1829 | 1833 |
|
| 1834 | +def extract_constituent_types(type_hint: Any) -> Set[Type]: |
| 1835 | + """ |
| 1836 | + Extract all constituent types from a type hint, including generics and their type arguments. |
| 1837 | + This function recursively traverses complex type hints to find all underlying types. |
| 1838 | +
|
| 1839 | + For example: |
| 1840 | + - For Union[int, str] -> {int, str} |
| 1841 | + - For List[int] -> {list, int} |
| 1842 | + - For Dict[str, Optional[int]] -> {dict, str, int} |
| 1843 | + - For Callable[..., int] -> {Callable} |
| 1844 | + - For TypeVar('T') -> {TypeVar} |
| 1845 | + - For Annotated[List[int], "metadata"] -> {list, int} |
| 1846 | + - For Optional[List[int]] -> {list, int} |
| 1847 | + - For a class MyClass -> {MyClass} |
| 1848 | +
|
| 1849 | + The function handles: |
| 1850 | + - Basic types (int, str, etc.) |
| 1851 | + - Generic types (List, Dict, etc.) |
| 1852 | + - Union and Optional types |
| 1853 | + - Annotated types (extracts the underlying type) |
| 1854 | + - TypeVars and ForwardRefs |
| 1855 | + - Callable types |
| 1856 | + - Custom classes |
| 1857 | +
|
| 1858 | + Args: |
| 1859 | + type_hint: A type hint to analyze. Can be any valid Python type annotation, |
| 1860 | + including complex nested types. |
| 1861 | +
|
| 1862 | + Returns: |
| 1863 | + A set of all constituent types found in the type hint. NoneType is excluded |
| 1864 | + from the results. |
| 1865 | +
|
| 1866 | + Note: |
| 1867 | + This function is particularly useful for type checking and reflection, |
| 1868 | + where you need to know all the possible types that could be involved |
| 1869 | + in a type annotation. |
| 1870 | + """ |
| 1871 | + # Special case for functions and classes - return the type itself |
| 1872 | + if inspect.isfunction(type_hint) or inspect.isclass(type_hint): |
| 1873 | + return {type_hint} |
| 1874 | + |
| 1875 | + # Handle older style type hints (_GenericAlias) |
| 1876 | + if hasattr(typing, "_GenericAlias") and isinstance(type_hint, typing._GenericAlias): # type: ignore |
| 1877 | + # Correctly handle Annotated by getting the first argument (the actual type) |
| 1878 | + if str(type_hint).startswith("typing.Annotated") or str(type_hint).startswith( |
| 1879 | + "typing_extensions.Annotated" |
| 1880 | + ): |
| 1881 | + # Recurse on the actual type, skipping metadata |
| 1882 | + return extract_constituent_types(type_hint.__args__[0]) |
| 1883 | + else: |
| 1884 | + origin = type_hint.__origin__ |
| 1885 | + |
| 1886 | + if origin in RECURSIVE_TYPES: |
| 1887 | + types = set() |
| 1888 | + for arg in type_hint.__args__: |
| 1889 | + # Add check to skip NoneType here as well |
| 1890 | + if arg is not type(None): |
| 1891 | + types.update(extract_constituent_types(arg)) |
| 1892 | + return types |
| 1893 | + # If not a recursive type handled above, treat it like a concrete generic |
| 1894 | + # Collect types from arguments |
| 1895 | + result = set() |
| 1896 | + for arg in type_hint.__args__: |
| 1897 | + if arg is not type( |
| 1898 | + None |
| 1899 | + ): # Also skip NoneType here for generics like list[Optional[int]] |
| 1900 | + result.update(extract_constituent_types(arg)) |
| 1901 | + # Add the origin itself (e.g., list, dict) |
| 1902 | + if isinstance(origin, type): |
| 1903 | + result.add(origin) |
| 1904 | + # Add the original type_hint if it's a specific generic instantiation (and not a Union/Optional) |
| 1905 | + if origin is not None and origin not in RECURSIVE_TYPES: |
| 1906 | + result.add(type_hint) # type_hint is the _GenericAlias itself |
| 1907 | + return result # Return collected types |
| 1908 | + |
| 1909 | + # Handle Python 3.9+ style type hints |
| 1910 | + origin = typing.get_origin(type_hint) |
| 1911 | + args = typing.get_args(type_hint) |
| 1912 | + |
| 1913 | + # Base case: no origin or args means it's a simple type |
| 1914 | + if origin is None: |
| 1915 | + if type_hint is type(None): |
| 1916 | + return set() |
| 1917 | + if isinstance(type_hint, type): |
| 1918 | + return {type_hint} |
| 1919 | + return {type_hint} # Return the hint itself if not a type (e.g., TypeVar) |
| 1920 | + |
| 1921 | + # Handle Annotated for Python 3.9+ |
| 1922 | + if origin is Annotated: |
| 1923 | + # Recurse on the actual type argument, skipping metadata |
| 1924 | + return extract_constituent_types(args[0]) |
| 1925 | + |
| 1926 | + # Union type (including Optional) |
| 1927 | + if origin is Union: |
| 1928 | + result = set() |
| 1929 | + for arg in args: |
| 1930 | + if arg is not type(None): # Skip NoneType in Unions |
| 1931 | + result.update(extract_constituent_types(arg)) |
| 1932 | + return result |
| 1933 | + |
| 1934 | + # List, Dict, etc. - collect types from arguments |
| 1935 | + result = set() |
| 1936 | + for arg in args: |
| 1937 | + result.update(extract_constituent_types(arg)) |
| 1938 | + |
| 1939 | + # Include the origin type itself if it's a class |
| 1940 | + # This handles both typing module types and Python 3.9+ built-in generic types |
| 1941 | + if isinstance(origin, type): |
| 1942 | + result.add(origin) |
| 1943 | + |
| 1944 | + # Add the original type_hint if it's a specific generic instantiation (and not a Union/Annotated) |
| 1945 | + if origin not in (None, Union, Annotated): |
| 1946 | + # type_hint is the original parameterized generic, e.g., List[int] |
| 1947 | + # Add it only if it's indeed a generic (origin of type_hint itself is not None) |
| 1948 | + if typing.get_origin(type_hint) is not None: |
| 1949 | + result.add(type_hint) |
| 1950 | + |
| 1951 | + # If no types were added, return the original type hint to preserve behavior |
| 1952 | + if ( |
| 1953 | + not result |
| 1954 | + ): # This covers cases like type_hint being a TypeVar that resulted in an empty set initially |
| 1955 | + return {type_hint} |
| 1956 | + |
| 1957 | + return result |
| 1958 | + |
| 1959 | + |
if __name__ == "__main__":
    # Executed as a script: build the CLI application, then hand control to it.
    app = create_cli()
    app()
0 commit comments