Skip to content

Commit a4b01fe

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(bugfix)(torchx/components) handle BinOp style optional when decoding component arguments
Summary: Fixes https://fb.workplace.com/groups/140700188041197/posts/1025813616196512/?comment_id=1025888432855697&reply_comment_id=1031032659007941 Adds handling for BinOp style (introduced in python-3.10 - **PEP 604**) when decoding component arguments from cli arguments. For example: ``` def my_component(env: dict[str, str] | None = None) -> AppDef: ... ``` Handles passing `torchx run foo.py:my_component --env FOO:BAR;BAR=BAZ` Differential Revision: D82856462
1 parent 6ab9f69 commit a4b01fe

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

torchx/specs/builders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _create_args_parser(
3939

4040

4141
def _create_args_parser_from_parameters(
42-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
42+
cmpnt_fn: Callable[..., Any],
4343
parameters: Mapping[str, inspect.Parameter],
4444
cmpnt_defaults: Optional[Dict[str, str]] = None,
4545
config: Optional[Dict[str, Any]] = None,
@@ -120,7 +120,7 @@ def _merge_config_values_with_args(
120120

121121

122122
def parse_args(
123-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
123+
cmpnt_fn: Callable[..., Any],
124124
cmpnt_args: List[str],
125125
cmpnt_defaults: Optional[Dict[str, Any]] = None,
126126
config: Optional[Dict[str, Any]] = None,
@@ -149,7 +149,7 @@ def parse_args(
149149

150150

151151
def component_args_from_str(
152-
cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type
152+
cmpnt_fn: Callable[..., Any],
153153
cmpnt_args: list[str],
154154
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
155155
config: Optional[Dict[str, Any]] = None,
@@ -238,7 +238,7 @@ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
238238

239239

240240
def materialize_appdef(
241-
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
241+
cmpnt_fn: Callable[..., Any],
242242
cmpnt_args: List[str],
243243
cmpnt_defaults: Optional[Dict[str, Any]] = None,
244244
config: Optional[Dict[str, Any]] = None,

torchx/specs/test/builders_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def example_test_complex_fn(
9393
nnodes: int = 4,
9494
first_arg: Optional[str] = None,
9595
nested_arg: Optional[Dict[str, List[str]]] = None,
96+
env: dict[str, str] | None = None,
9697
*roles_args: str,
9798
) -> AppDef:
9899
"""Creates complex application, testing all possible complex types
@@ -127,6 +128,7 @@ def example_test_complex_fn(
127128
args=args,
128129
resource=Resource(cpu=cpus, gpu=gpus, memMB=1),
129130
num_replicas=nnodes,
131+
env=env or {},
130132
)
131133
roles.append(role)
132134
return AppDef(app_name, roles)
@@ -193,6 +195,7 @@ def _get_expected_app_with_default(self) -> AppDef:
193195
4,
194196
None,
195197
None,
198+
None,
196199
*role_args,
197200
)
198201

@@ -220,6 +223,7 @@ def _get_expected_app_with_all_args(self) -> AppDef:
220223
8,
221224
"first_arg",
222225
None,
226+
{"FOO": "BAR", "HELLO": "WORLD"},
223227
*role_args,
224228
)
225229

@@ -240,6 +244,8 @@ def _get_app_args(self) -> List[str]:
240244
"8",
241245
"--first_arg",
242246
"first_arg",
247+
"--env",
248+
"FOO=BAR,HELLO=WORLD",
243249
"--",
244250
*role_args,
245251
]
@@ -256,6 +262,7 @@ def _get_expected_app_with_nested_objects(self) -> AppDef:
256262
8,
257263
"first_arg",
258264
defaults,
265+
None,
259266
*role_args,
260267
)
261268

torchx/util/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import inspect
1010
import re
11+
from types import UnionType
1112
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
1213

1314

@@ -234,10 +235,20 @@ def decode_optional(param_type: Any) -> Any:
234235
If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
235236
Otherwise returns ``param_type``
236237
"""
238+
237239
if not hasattr(param_type, "__origin__"):
238-
return param_type
240+
if isinstance(param_type, UnionType):
241+
# handle BinOp style Optional (e.g. `T | None`)
242+
if len(param_type.__args__) == 2 and param_type.__args__[1] is type(None):
243+
return param_type.__args__[0]
244+
else:
245+
return param_type
246+
else:
247+
return param_type
248+
239249
if param_type.__origin__ is not Union:
240250
return param_type
251+
241252
args = param_type.__args__
242253
if len(args) == 2 and args[1] is type(None):
243254
return args[0]

0 commit comments

Comments
 (0)