Skip to content

Commit 4a70b85

Browse files
Daniel Ohayonfacebook-github-bot
authored andcommitted
support dataclasses for component args
Summary: This change adds the ability to provide component arguments via a single dataclass (which can still be combined with positional varargs). One motivation for supporting dataclasses is to facilitate composition of components, eg consider a scenario where `component2` is an extension of `component1` – it takes the same args as component1 + some extra ones, calls `component1` and does some extra work on top with the additional arguments. ## Before ```lang=python def component1(arg1: str, arg2: int, ...argN: str): ... # need to repeat all arguments, unclear which ones are specific to component2 def component2(arg1: str, arg2: int, ... argN: str, otherArg: str): # need to pass all component1-specific args to component1 explicitly app_def = component1(arg1, arg2, ... argN) return do_something(app_def, otherArg) ``` ## After ```lang=python dataclass class Comp1Args: arg1: str arg2: int ... argN: str # separation of args for the 2 components is explicit via the dataclasses # no need to spell out all the args thanks to inheritance dataclass class Comp2Args(Comp1Args): other_arg: str def component1(args: Comp1Args): ... def component2(args: Comp2Args): # no need to spell out all the args when calling component1 app_def = component1(args) return do_something(app_def, other_arg) ``` Differential Revision: D75320316
1 parent 24dc0d5 commit 4a70b85

File tree

2 files changed

+120
-9
lines changed

2 files changed

+120
-9
lines changed

torchx/specs/builders.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
# pyre-strict
88

99
import argparse
10+
import dataclasses
1011
import inspect
1112
import os
1213
from argparse import Namespace
14+
from dataclasses import dataclass
1315
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
1416

1517
from torchx.specs.api import BindMount, MountType, VolumeMount
@@ -24,12 +26,74 @@ def _create_args_parser(
2426
cmpnt_defaults: Optional[Dict[str, str]] = None,
2527
config: Optional[Dict[str, Any]] = None,
2628
) -> argparse.ArgumentParser:
27-
parameters = inspect.signature(cmpnt_fn).parameters
29+
parameters = _get_params_from_component_signature(cmpnt_fn).parameters
2830
return _create_args_parser_from_parameters(
2931
cmpnt_fn, parameters, cmpnt_defaults, config
3032
)
3133

3234

35+
@dataclass
36+
class SignatureInfo:
37+
parameters: Mapping[str, inspect.Parameter]
38+
dataclass_type: type[object] | None
39+
40+
41+
def _get_params_from_component_signature(
42+
cmpnt_fn: Callable[..., AppDef]
43+
) -> SignatureInfo:
44+
parameters = inspect.signature(cmpnt_fn).parameters
45+
dataclass_type = _maybe_get_dataclass_type(parameters)
46+
if dataclass_type is not None:
47+
parameters = _flatten_dataclass_params(parameters)
48+
return SignatureInfo(parameters, dataclass_type)
49+
50+
51+
def _maybe_get_dataclass_type(
52+
parameters: Mapping[str, inspect.Parameter]
53+
) -> type[object] | None:
54+
if len(parameters) not in (1, 2):
55+
# only support a single dataclass or a single dataclass followed by a vararg
56+
return None
57+
params = list(parameters.values())
58+
first_param_type = params[0].annotation
59+
is_first_param_dataclass = dataclasses.is_dataclass(
60+
first_param_type
61+
) and isinstance(first_param_type, type)
62+
if not is_first_param_dataclass:
63+
return None
64+
if len(params) == 1:
65+
return first_param_type
66+
if len(params) == 2 and params[1].kind == inspect.Parameter.VAR_POSITIONAL:
67+
return first_param_type
68+
return None
69+
70+
71+
def _flatten_dataclass_params(
72+
parameters: Mapping[str, inspect.Parameter]
73+
) -> Mapping[str, inspect.Parameter]:
74+
result = {}
75+
76+
for param_name, param in parameters.items():
77+
param_type = param.annotation
78+
if not dataclasses.is_dataclass(param_type):
79+
result[param_name] = param
80+
continue
81+
else:
82+
result.update(
83+
{
84+
f.name: inspect.Parameter(
85+
f.name,
86+
inspect._ParameterKind.KEYWORD_ONLY,
87+
annotation=f.type,
88+
default=f.default,
89+
)
90+
for f in dataclasses.fields(param_type)
91+
}
92+
)
93+
94+
return result
95+
96+
3397
def _create_args_parser_from_parameters(
3498
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
3599
parameters: Mapping[str, inspect.Parameter],
@@ -69,7 +133,7 @@ def __call__(
69133
)
70134

71135
for param_name, parameter in parameters.items():
72-
param_desc = args_desc[parameter.name]
136+
param_desc = args_desc.get(parameter.name)
73137
args: Dict[str, Any] = {
74138
"help": param_desc,
75139
"type": get_argparse_param_type(parameter),
@@ -147,18 +211,27 @@ def materialize_appdef(
147211
config: Optional[Dict[str, Any]] = None,
148212
) -> AppDef:
149213
"""
150-
Creates an application by running user defined ``app_fn``.
214+
Creates an application by running a user-defined component function ``cmpnt_fn``.
151215
152-
``app_fn`` has the following restrictions:
153-
* Name must be ``app_fn``
216+
``cmpnt_fn`` has the following restrictions:
154217
* All arguments should be annotated
155218
* Supported argument types:
156219
- primitive: int, str, float
157220
- Dict[primitive, primitive]
158221
- List[primitive]
159222
- Optional[Dict[primitive, primitive]]
160223
- Optional[List[primitive]]
161-
* ``app_fn`` can define a vararg (*arg) at the end
224+
225+
The arguments can also be passed as a single dataclass, e.g.
226+
227+
@dataclass
228+
class Args:
229+
arg1: str
230+
arg2: Dict[str, int]
231+
232+
def cmpnt_fn(args: Args) -> AppDef: ...
233+
234+
* ``cmpnt_fn`` can define a vararg (*arg) at the end (this also works if the first argument is a dataclass)
162235
* There should be a docstring for the function that defines
163236
All arguments in a google-style format
164237
* There can be default values for the function arguments.
@@ -180,8 +253,9 @@ def materialize_appdef(
180253

181254
parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
182255

183-
parameters = inspect.signature(cmpnt_fn).parameters
184-
for param_name, parameter in parameters.items():
256+
signature_info = _get_params_from_component_signature(cmpnt_fn)
257+
258+
for param_name, parameter in signature_info.parameters.items():
185259
arg_value = getattr(parsed_args, param_name)
186260
parameter_type = parameter.annotation
187261
parameter_type = decode_optional(parameter_type)
@@ -197,6 +271,9 @@ def materialize_appdef(
197271
if len(var_arg) > 0 and var_arg[0] == "--":
198272
var_arg = var_arg[1:]
199273

274+
if signature_info.dataclass_type is not None:
275+
function_args = [signature_info.dataclass_type(**kwargs)]
276+
kwargs = {}
200277
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
201278
if not isinstance(appdef, AppDef):
202279
raise TypeError(

torchx/specs/test/builders_test.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import argparse
1010
import sys
1111
import unittest
12-
from dataclasses import asdict
12+
from dataclasses import asdict, dataclass
1313
from pathlib import Path
1414
from typing import Any, Dict, List, Optional, Tuple
1515
from unittest.mock import patch
@@ -130,6 +130,34 @@ def example_test_complex_fn(
130130
return AppDef(app_name, roles)
131131

132132

133+
@dataclass
134+
class ComplexArgs:
135+
app_name: str
136+
containers: List[str]
137+
roles_scripts: Dict[str, str]
138+
num_cpus: Optional[List[int]] = None
139+
num_gpus: Optional[Dict[str, int]] = None
140+
nnodes: int = 4
141+
first_arg: Optional[str] = None
142+
nested_arg: Optional[Dict[str, List[str]]] = None
143+
144+
145+
def example_test_complex_fn_dataclass_arg(
146+
args: ComplexArgs, *roles_args: str
147+
) -> AppDef:
148+
return example_test_complex_fn(
149+
args.app_name,
150+
args.containers,
151+
args.roles_scripts,
152+
args.num_cpus,
153+
args.num_gpus,
154+
args.nnodes,
155+
args.first_arg,
156+
args.nested_arg,
157+
*roles_args,
158+
)
159+
160+
133161
_TEST_VAR_ARGS: Optional[Tuple[object, ...]] = None
134162

135163

@@ -292,6 +320,12 @@ def test_load_from_fn_complex_all_args(self) -> None:
292320
actual_app = materialize_appdef(example_test_complex_fn, app_args)
293321
self.assert_apps(expected_app, actual_app)
294322

323+
def test_load_from_fn_complex_all_args_dataclass(self) -> None:
324+
expected_app = self._get_expected_app_with_all_args()
325+
app_args = self._get_app_args()
326+
actual_app = materialize_appdef(example_test_complex_fn_dataclass_arg, app_args)
327+
self.assert_apps(expected_app, actual_app)
328+
295329
def test_required_args(self) -> None:
296330
with patch.object(sys, "exit") as exit_mock:
297331
try:

0 commit comments

Comments
 (0)