Skip to content

Commit a6ced3a

Browse files
ishachirimar authored and facebook-github-bot committed
Support JSON input to torchx run (#1107)
Summary: Add support for piping JSON into the `torchx run` command, e.g. `JSON | torchx run`, where the JSON contains the args for `torchx run` (while maintaining the functionality of the regular CLI). In this diff, stdin is read only when the `--stdin` flag is passed. The JSON can come either directly from the terminal or from a file via `cat json_file.json`. Next steps: (1) add a JSON schema to improve the authoring experience for this new config format; (2) add more robust validation; (3) accept varargs for component args and test this case. Reviewed By: daniel-ohayon. Differential Revision: D80731827
1 parent ca48a7c commit a6ced3a

File tree

5 files changed

+212
-31
lines changed

5 files changed

+212
-31
lines changed

torchx/cli/cmd_run.py

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import argparse
10+
import json
1011
import logging
1112
import os
1213
import sys
@@ -41,6 +42,12 @@
4142
"missing component name, either provide it from the CLI or in .torchxconfig"
4243
)
4344

45+
# Deprecation notice logged when the plain `local` scheduler is selected.
LOCAL_SCHEDULER_WARNING_MSG: str = (
    "`local` scheduler is deprecated and will be removed in the near future,"
    " please use other variants of the local scheduler (e.g. `local_cwd`)"
)
4451

4552
logger: logging.Logger = logging.getLogger(__name__)
4653

@@ -256,35 +263,35 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
256263
default=False,
257264
help="Add additional prefix to log lines to indicate which replica is printing the log",
258265
)
266+
subparser.add_argument(
267+
"--stdin",
268+
action="store_true",
269+
default=False,
270+
help="Read JSON input from stdin to parse into torchx run args and run the component.",
271+
)
259272
subparser.add_argument(
260273
"component_name_and_args",
261274
nargs=argparse.REMAINDER,
262275
)
263276

264-
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
277+
def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
265278
if args.scheduler == "local":
266-
logger.warning(
267-
"`local` scheduler is deprecated and will be"
268-
" removed in the near future,"
269-
" please use other variants of the local scheduler"
270-
" (e.g. `local_cwd`)"
271-
)
272-
273-
cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args))
274-
config.apply(scheduler=args.scheduler, cfg=cfg)
279+
logger.warning(LOCAL_SCHEDULER_WARNING_MSG)
275280

276-
component, component_args = _parse_component_name_and_args(
277-
args.component_name_and_args,
278-
none_throws(self._subparser),
281+
config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg)
282+
component_args = (
283+
args.component_args_str
284+
if args.component_args_str != []
285+
else args.component_args
279286
)
280287
try:
281288
if args.dryrun:
282289
dryrun_info = runner.dryrun_component(
283-
component,
290+
args.component_name,
284291
component_args,
285292
args.scheduler,
286293
workspace=args.workspace,
287-
cfg=cfg,
294+
cfg=args.scheduler_cfg,
288295
parent_run_id=args.parent_run_id,
289296
)
290297
print(
@@ -295,11 +302,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
295302
print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}")
296303
else:
297304
app_handle = runner.run_component(
298-
component,
305+
args.component_name,
299306
component_args,
300307
args.scheduler,
301308
workspace=args.workspace,
302-
cfg=cfg,
309+
cfg=args.scheduler_cfg,
303310
parent_run_id=args.parent_run_id,
304311
)
305312
# DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -320,7 +327,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
320327
)
321328

322329
except (ComponentValidationException, ComponentNotFoundException) as e:
323-
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
330+
error_msg = (
331+
f"\nFailed to run component `{args.component_name}` got errors: \n {e}"
332+
)
324333
logger.error(error_msg)
325334
sys.exit(1)
326335
except specs.InvalidRunConfigException as e:
@@ -335,6 +344,86 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
335344
print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
336345
sys.exit(1)
337346

347+
def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
    """
    Runs the component described by regular (argparse) CLI arguments.

    The ``scheduler_args`` string is turned into a scheduler cfg through the
    run opts of the selected scheduler, the trailing CLI tokens are split into
    a component name and its arguments, and the result is packed with
    ``torchx_run_args_from_argparse`` before being handed to ``_run_inner``.

    Args:
        runner: runner used to look up scheduler run opts and to launch the app.
        args: the parsed ``torchx run`` namespace.
    """
    run_opts = runner.scheduler_run_opts(args.scheduler)
    scheduler_cfg = run_opts.cfg_from_str(args.scheduler_args)

    name, name_args = _parse_component_name_and_args(
        args.component_name_and_args,
        none_throws(self._subparser),
    )
    self._run_inner(
        runner,
        torchx_run_args_from_argparse(args, name, name_args, scheduler_cfg),
    )
359+
360+
def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
    """
    Runs the component described by a JSON document read from stdin.

    The dict is converted with ``torchx_run_args_from_json``; its
    ``scheduler_args`` are re-serialized to JSON and parsed into a scheduler
    cfg by the selected scheduler's run opts, stored on the run args as
    ``scheduler_cfg``, and the result is handed to ``_run_inner``.

    Args:
        runner: runner used to look up scheduler run opts and to launch the app.
        stdin_data: the decoded top-level JSON object.
    """
    run_args = torchx_run_args_from_json(stdin_data)
    run_opts = runner.scheduler_run_opts(run_args.scheduler)
    scheduler_args_json = json.dumps(run_args.scheduler_args)
    run_args.scheduler_cfg = run_opts.cfg_from_json_repr(scheduler_args_json)
    self._run_inner(runner, run_args)
368+
369+
def torchx_json_from_stdin(self) -> Dict[str, Any]:
    """
    Reads one JSON document from stdin and returns it as a dict.

    Returns:
        The decoded top-level JSON object.

    Logs an error and exits the process with status 1 when stdin does not
    hold valid JSON, or when the top-level JSON value is not an object.
    """
    # Keep the try body to the single call that can raise a decode error.
    try:
        stdin_data_json = json.load(sys.stdin)
    except (json.JSONDecodeError, EOFError) as e:
        # Surface the decoder's message (it carries line/column) so the user
        # can locate the problem instead of only learning that parsing failed.
        logger.error(
            "Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
            " Parse error: %s",
            e,
        )
        sys.exit(1)

    if not isinstance(stdin_data_json, dict):
        logger.error(
            "Invalid JSON input for `torchx run` command. Expected a dictionary."
        )
        sys.exit(1)
    return stdin_data_json
383+
384+
def verify_no_extra_args(self, args: argparse.Namespace) -> None:
    """
    Verifies that only --stdin was provided when using stdin mode.

    Every argument registered on the subparser, other than ``--stdin`` and
    ``--help``, is compared against its declared default. Arguments whose
    parsed value differs are reported together through ``subparser.error``,
    which terminates the process.

    Args:
        args: the parsed ``torchx run`` namespace.
    """
    if not args.stdin:
        return

    subparser = none_throws(self._subparser)
    conflicting_args = []

    for action in subparser._actions:
        # Skip --stdin itself and --help.
        if action.dest in ("stdin", "help"):
            continue

        current_value = getattr(args, action.dest, None)
        if current_value == action.default:
            continue

        # Non-default does not always mean "explicitly set": the REMAINDER
        # positional parses to [] when no component was given.
        if action.dest == "component_name_and_args" and current_value == []:
            continue

        # Report the argument the way it is actually registered. Rebuilding
        # the name from `dest` with dashes produced flags that do not exist
        # (e.g. `--scheduler-args` instead of `--scheduler_args`) and turned
        # the positional into a bogus `--component-name-and-args` option.
        if action.option_strings:
            display_name = next(
                (opt for opt in action.option_strings if opt.startswith("--")),
                action.option_strings[0],
            )
        else:
            display_name = action.dest
        conflicting_args.append(display_name)

    if conflicting_args:
        subparser.error(
            f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
            "All configuration should be provided in JSON input."
        )
417+
418+
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
    """
    Dispatches ``torchx run`` to the stdin (JSON) or the regular CLI code path.

    Args:
        runner: runner forwarded to the selected code path.
        args: the parsed ``torchx run`` namespace.
    """
    # When --stdin is used, no other argument may be given on the command line.
    self.verify_no_extra_args(args)

    if not args.stdin:
        self._run_from_cli_args(runner, args)
        return

    self._run_from_stdin_args(runner, self.torchx_json_from_stdin())
426+
338427
def run(self, args: argparse.Namespace) -> None:
339428
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
340429
component_defaults = load_sections(prefix="component")

torchx/cli/test/cmd_run_test.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import argparse
1111
import dataclasses
1212
import io
13-
1413
import os
1514
import shutil
1615
import signal
@@ -67,14 +66,12 @@ def tearDown(self) -> None:
6766
torchxconfig.called_args = set()
6867

6968
def test_run_with_multiple_scheduler_args(self) -> None:
70-
7169
args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"]
7270
with self.assertRaises(SystemExit) as cm:
7371
self.parser.parse_args(args)
7472
self.assertEqual(cm.exception.code, 1)
7573

7674
def test_run_with_multiple_schedule_args(self) -> None:
77-
7875
args = [
7976
"--scheduler",
8077
"local_cwd",
@@ -179,13 +176,13 @@ def test_conf_file_missing(self) -> None:
179176
with patch(
180177
"torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir]
181178
):
179+
args = self.parser.parse_args(
180+
[
181+
"--scheduler",
182+
"local_cwd",
183+
]
184+
)
182185
with self.assertRaises(SystemExit):
183-
args = self.parser.parse_args(
184-
[
185-
"--scheduler",
186-
"local_cwd",
187-
]
188-
)
189186
self.cmd_run.run(args)
190187

191188
@patch("torchx.runner.Runner.run")
@@ -364,6 +361,91 @@ def test_parse_component_name_and_args_with_default(self) -> None:
364361
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
365362
)
366363

364+
def test_verify_no_extra_args_stdin_only(self) -> None:
    """A bare ``--stdin`` invocation passes verification."""
    stdin_only = self.parser.parse_args(["--stdin"])
    # Must return normally: nothing besides --stdin was supplied.
    self.cmd_run.verify_no_extra_args(stdin_only)
369+
370+
def test_verify_no_extra_args_no_stdin(self) -> None:
    """Without ``--stdin`` the verification is skipped, whatever else is passed."""
    argv = ["--scheduler", "local_cwd", "utils.echo"]
    parsed = self.parser.parse_args(argv)
    # Must return normally instead of exiting.
    self.cmd_run.verify_no_extra_args(parsed)
375+
376+
def test_verify_no_extra_args_stdin_with_component_name(self) -> None:
    """A positional component name alongside ``--stdin`` is rejected."""
    parsed = self.parser.parse_args(["--stdin", "utils.echo"])
    self.assertRaises(SystemExit, self.cmd_run.verify_no_extra_args, parsed)
381+
382+
def test_verify_no_extra_args_stdin_with_scheduler_args(self) -> None:
    """``--scheduler_args`` alongside ``--stdin`` is rejected."""
    argv = ["--stdin", "--scheduler_args", "cluster=test"]
    parsed = self.parser.parse_args(argv)
    self.assertRaises(SystemExit, self.cmd_run.verify_no_extra_args, parsed)
387+
388+
def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:
    """A non-default ``--scheduler`` alongside ``--stdin`` is rejected."""
    argv = ["--stdin", "--scheduler", "kubernetes"]
    parsed = self.parser.parse_args(argv)
    self.assertRaises(SystemExit, self.cmd_run.verify_no_extra_args, parsed)
393+
394+
def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
    """Test that boolean flags conflict with stdin."""
    boolean_flags = ("--dryrun", "--wait", "--log", "--tee_logs")
    for flag in boolean_flags:
        # subTest names the offending flag in the failure report and lets the
        # remaining flags be checked even after one of them fails.
        with self.subTest(flag=flag):
            args = self.parser.parse_args(["--stdin", flag])
            with self.assertRaises(SystemExit):
                self.cmd_run.verify_no_extra_args(args)
401+
402+
def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
    """Value-taking options (workspace, parent run id) are rejected with ``--stdin``."""
    argv_cases = (
        ["--stdin", "--workspace", "file:///custom/path"],
        ["--stdin", "--parent_run_id", "experiment_123"],
    )
    for argv in argv_cases:
        parsed = self.parser.parse_args(argv)
        self.assertRaises(SystemExit, self.cmd_run.verify_no_extra_args, parsed)
411+
412+
def test_verify_no_extra_args_stdin_with_multiple_conflicts(self) -> None:
    """Several conflicting arguments given with ``--stdin`` are rejected."""
    argv = ["--stdin", "--dryrun", "--wait", "--scheduler_args", "cluster=test"]
    parsed = self.parser.parse_args(argv)
    self.assertRaises(SystemExit, self.cmd_run.verify_no_extra_args, parsed)
419+
420+
def test_verify_no_extra_args_stdin_with_default_scheduler(self) -> None:
    """Explicitly passing the default scheduler together with ``--stdin`` is allowed."""
    # Local import: only this test needs the default scheduler name.
    from torchx.schedulers import get_default_scheduler_name

    argv = ["--stdin", "--scheduler", get_default_scheduler_name()]
    parsed = self.parser.parse_args(argv)
    # Must return normally: the value matches the default, so no conflict.
    self.cmd_run.verify_no_extra_args(parsed)
430+
431+
@patch("torchx.cli.cmd_run.Path.cwd")
def test_verify_no_extra_args_stdin_with_default_workspace(
    self, mock_cwd: MagicMock
) -> None:
    """Explicitly passing the default workspace together with ``--stdin`` is allowed."""
    fake_cwd = Path("mock/workspace/path")
    mock_cwd.return_value = fake_cwd

    # Register the arguments only after Path.cwd() is patched, so the fresh
    # parser sees the mocked cwd (presumably the source of the workspace
    # default -- confirm against add_arguments).
    cmd = CmdRun()
    parser = argparse.ArgumentParser()
    cmd.add_arguments(parser)

    argv = ["--stdin", "--workspace", f"file://{fake_cwd}"]
    # Must return normally: the value matches the default, so no conflict.
    cmd.verify_no_extra_args(parser.parse_args(argv))
448+
367449

368450
class CmdBuiltinTest(unittest.TestCase):
369451
def test_run(self) -> None:

torchx/runner/api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def close(self) -> None:
167167
def run_component(
168168
self,
169169
component: str,
170-
component_args: List[str],
170+
component_args: list[str] | dict[str, Any],
171171
scheduler: str,
172172
cfg: Optional[Mapping[str, CfgVal]] = None,
173173
workspace: Optional[str] = None,
@@ -226,7 +226,7 @@ def run_component(
226226
def dryrun_component(
227227
self,
228228
component: str,
229-
component_args: List[str],
229+
component_args: list[str] | dict[str, Any],
230230
scheduler: str,
231231
cfg: Optional[Mapping[str, CfgVal]] = None,
232232
workspace: Optional[str] = None,
@@ -237,10 +237,13 @@ def dryrun_component(
237237
component, but just returns what "would" have run.
238238
"""
239239
component_def = get_component(component)
240+
args_from_cli = component_args if isinstance(component_args, list) else []
241+
args_from_json = component_args if isinstance(component_args, dict) else {}
240242
app = materialize_appdef(
241243
component_def.fn,
242-
component_args,
244+
args_from_cli,
243245
self._component_defaults.get(component, None),
246+
args_from_json,
244247
)
245248
return self.dryrun(
246249
app,

torchx/specs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,8 @@ def gpu_x_1() -> Dict[str, Resource]:
225225
"make_app_handle",
226226
"materialize_appdef",
227227
"parse_mounts",
228+
"torchx_run_args_from_argparse",
229+
"torchx_run_args_from_json",
230+
"TorchXRunArgs",
228231
"ALL",
229232
]

torchx/specs/builders.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,11 @@ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
213213
arg_value = getattr(parsed_args, param_name)
214214
parameter_type = parameter.annotation
215215
parameter_type = decode_optional(parameter_type)
216-
arg_value = decode(arg_value, parameter_type)
216+
if (
217+
parameter_type != arg_value.__class__
218+
and parameter.kind != inspect.Parameter.VAR_POSITIONAL
219+
):
220+
arg_value = decode(arg_value, parameter_type)
217221
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
218222
var_args = arg_value
219223
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:

0 commit comments

Comments
 (0)