Skip to content

Commit d462df9

Browse files
ishachirimarfacebook-github-bot
authored andcommitted
Support JSON input to torchx run
Summary: add support for piping a json into torchx run command, like `JSON | torchx run` where the JSON contains the args for torchx run (while maintaining the functionality of the regular CLI) This can be done with either a JSON directly in the terminal or `cat json_file.json`. Next steps: - add a json schema to improve authoring experience for this new config format - more robust validation - accept varargs for component args/ test this case Reviewed By: daniel-ohayon Differential Revision: D80731827
1 parent 04f76e8 commit d462df9

File tree

4 files changed

+196
-32
lines changed

4 files changed

+196
-32
lines changed

torchx/cli/cmd_run.py

Lines changed: 106 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import argparse
10+
import json
1011
import logging
1112
import os
1213
import sys
@@ -41,6 +42,12 @@
4142
"missing component name, either provide it from the CLI or in .torchxconfig"
4243
)
4344

45+
LOCAL_SCHEDULER_WARNING_MSG = (
46+
"`local` scheduler is deprecated and will be"
47+
" removed in the near future,"
48+
" please use other variants of the local scheduler"
49+
" (e.g. `local_cwd`)"
50+
)
4451

4552
logger: logging.Logger = logging.getLogger(__name__)
4653

@@ -256,35 +263,31 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
256263
default=False,
257264
help="Add additional prefix to log lines to indicate which replica is printing the log",
258265
)
266+
subparser.add_argument(
267+
"--stdin",
268+
action="store_true",
269+
default=False,
270+
help="Read JSON input from stdin to parse into torchx run args and run the component.",
271+
)
259272
subparser.add_argument(
260273
"component_name_and_args",
261274
nargs=argparse.REMAINDER,
262275
)
263276

264-
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
277+
def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
265278
if args.scheduler == "local":
266-
logger.warning(
267-
"`local` scheduler is deprecated and will be"
268-
" removed in the near future,"
269-
" please use other variants of the local scheduler"
270-
" (e.g. `local_cwd`)"
271-
)
279+
logger.warning(LOCAL_SCHEDULER_WARNING_MSG)
272280

273-
cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args))
274-
config.apply(scheduler=args.scheduler, cfg=cfg)
281+
config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg)
275282

276-
component, component_args = _parse_component_name_and_args(
277-
args.component_name_and_args,
278-
none_throws(self._subparser),
279-
)
280283
try:
281284
if args.dryrun:
282285
dryrun_info = runner.dryrun_component(
283-
component,
284-
component_args,
286+
args.component_name,
287+
args.component_args,
285288
args.scheduler,
286289
workspace=args.workspace,
287-
cfg=cfg,
290+
cfg=args.scheduler_cfg,
288291
parent_run_id=args.parent_run_id,
289292
)
290293
print(
@@ -295,11 +298,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
295298
print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}")
296299
else:
297300
app_handle = runner.run_component(
298-
component,
299-
component_args,
301+
args.component_name,
302+
args.component_args_str,
300303
args.scheduler,
301304
workspace=args.workspace,
302-
cfg=cfg,
305+
cfg=args.scheduler_cfg,
303306
parent_run_id=args.parent_run_id,
304307
)
305308
# DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -320,7 +323,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
320323
)
321324

322325
except (ComponentValidationException, ComponentNotFoundException) as e:
323-
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
326+
error_msg = (
327+
f"\nFailed to run component `{args.component_name}` got errors: \n {e}"
328+
)
324329
logger.error(error_msg)
325330
sys.exit(1)
326331
except specs.InvalidRunConfigException as e:
@@ -335,6 +340,87 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
335340
print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
336341
sys.exit(1)
337342

343+
def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
344+
scheduler_opts = runner.scheduler_run_opts(args.scheduler)
345+
cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
346+
347+
component, component_args = _parse_component_name_and_args(
348+
args.component_name_and_args,
349+
none_throws(self._subparser),
350+
)
351+
torchx_run_args = torchx_run_args_from_argparse(
352+
args, component, component_args, cfg
353+
)
354+
self._run_inner(runner, torchx_run_args)
355+
356+
def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
357+
torchx_run_args = torchx_run_args_from_json(stdin_data)
358+
scheduler_opts = runner.scheduler_run_opts(torchx_run_args.scheduler)
359+
cfg = scheduler_opts.cfg_from_json_repr(
360+
json.dumps(torchx_run_args.scheduler_args)
361+
)
362+
torchx_run_args.scheduler_cfg = cfg
363+
self._run_inner(runner, torchx_run_args)
364+
365+
def torchx_json_from_stdin(self) -> Dict[str, Any]:
366+
try:
367+
stdin_data_json = json.load(sys.stdin)
368+
if not isinstance(stdin_data_json, dict):
369+
logger.error(
370+
"Invalid JSON input for `torchx run` command. Expected a dictionary."
371+
)
372+
sys.exit(1)
373+
return stdin_data_json
374+
except (json.JSONDecodeError, EOFError):
375+
logger.error(
376+
"Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
377+
)
378+
sys.exit(1)
379+
380+
def verify_no_extra_args(self, args: argparse.Namespace) -> None:
381+
"""
382+
Verifies that only --stdin was provided when using stdin mode.
383+
"""
384+
if not args.stdin:
385+
return
386+
387+
subparser = none_throws(self._subparser)
388+
conflicting_args = []
389+
390+
# Check each argument against its default value
391+
for action in subparser._actions:
392+
if action.dest == "stdin": # Skip stdin itself
393+
continue
394+
if action.dest == "help": # Skip help
395+
continue
396+
397+
current_value = getattr(args, action.dest, None)
398+
default_value = action.default
399+
400+
# For arguments that differ from default
401+
if current_value != default_value:
402+
# Handle special cases where non-default doesn't mean explicitly set
403+
if action.dest == "component_name_and_args" and current_value == []:
404+
continue # Empty list is still default
405+
406+
conflicting_args.append(f"--{action.dest.replace('_', '-')}")
407+
408+
if conflicting_args:
409+
subparser.error(
410+
f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
411+
"All configuration should be provided in JSON input."
412+
)
413+
414+
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
415+
# Verify no conflicting arguments when using to loop over the stdin
416+
self.verify_no_extra_args(args)
417+
418+
if args.stdin:
419+
stdin_data_json = self.torchx_json_from_stdin()
420+
self._run_from_stdin_args(runner, stdin_data_json)
421+
else:
422+
self._run_from_cli_args(runner, args)
423+
338424
def run(self, args: argparse.Namespace) -> None:
339425
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
340426
component_defaults = load_sections(prefix="component")

torchx/cli/test/cmd_run_test.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import argparse
1111
import dataclasses
1212
import io
13-
1413
import os
1514
import shutil
1615
import signal
@@ -67,14 +66,12 @@ def tearDown(self) -> None:
6766
torchxconfig.called_args = set()
6867

6968
def test_run_with_multiple_scheduler_args(self) -> None:
70-
7169
args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"]
7270
with self.assertRaises(SystemExit) as cm:
7371
self.parser.parse_args(args)
7472
self.assertEqual(cm.exception.code, 1)
7573

7674
def test_run_with_multiple_schedule_args(self) -> None:
77-
7875
args = [
7976
"--scheduler",
8077
"local_cwd",
@@ -179,13 +176,13 @@ def test_conf_file_missing(self) -> None:
179176
with patch(
180177
"torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir]
181178
):
179+
args = self.parser.parse_args(
180+
[
181+
"--scheduler",
182+
"local_cwd",
183+
]
184+
)
182185
with self.assertRaises(SystemExit):
183-
args = self.parser.parse_args(
184-
[
185-
"--scheduler",
186-
"local_cwd",
187-
]
188-
)
189186
self.cmd_run.run(args)
190187

191188
@patch("torchx.runner.Runner.run")
@@ -364,6 +361,81 @@ def test_parse_component_name_and_args_with_default(self) -> None:
364361
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
365362
)
366363

364+
def test_verify_no_extra_args_stdin_only(self) -> None:
365+
"""Test that only --stdin is allowed when using stdin mode."""
366+
args = self.parser.parse_args(["--stdin"])
367+
# Should not raise any exception
368+
self.cmd_run.verify_no_extra_args(args)
369+
370+
def test_verify_no_extra_args_no_stdin(self) -> None:
371+
"""Test that verification is skipped when not using stdin."""
372+
args = self.parser.parse_args(["--scheduler", "local_cwd", "utils.echo"])
373+
# Should not raise any exception
374+
self.cmd_run.verify_no_extra_args(args)
375+
376+
def test_verify_no_extra_args_stdin_with_component_name(self) -> None:
377+
"""Test that component name conflicts with stdin."""
378+
args = self.parser.parse_args(["--stdin", "utils.echo"])
379+
with self.assertRaises(SystemExit):
380+
self.cmd_run.verify_no_extra_args(args)
381+
382+
def test_verify_no_extra_args_stdin_with_scheduler_args(self) -> None:
383+
"""Test that scheduler_args conflicts with stdin."""
384+
args = self.parser.parse_args(["--stdin", "--scheduler_args", "cluster=test"])
385+
with self.assertRaises(SystemExit):
386+
self.cmd_run.verify_no_extra_args(args)
387+
388+
def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:
389+
"""Test that non-default scheduler conflicts with stdin."""
390+
args = self.parser.parse_args(["--stdin", "--scheduler", "kubernetes"])
391+
with self.assertRaises(SystemExit):
392+
self.cmd_run.verify_no_extra_args(args)
393+
394+
def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
395+
"""Test that boolean flags conflict with stdin."""
396+
boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"]
397+
for flag in boolean_flags:
398+
args = self.parser.parse_args(["--stdin", flag])
399+
with self.assertRaises(SystemExit):
400+
self.cmd_run.verify_no_extra_args(args)
401+
402+
def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
403+
"""Test that arguments with values conflict with stdin."""
404+
args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"])
405+
with self.assertRaises(SystemExit):
406+
self.cmd_run.verify_no_extra_args(args)
407+
408+
args = self.parser.parse_args(["--stdin", "--parent_run_id", "experiment_123"])
409+
with self.assertRaises(SystemExit):
410+
self.cmd_run.verify_no_extra_args(args)
411+
412+
def test_verify_no_extra_args_stdin_with_multiple_conflicts(self) -> None:
413+
"""Test that multiple conflicting arguments with stdin are detected."""
414+
args = self.parser.parse_args(
415+
["--stdin", "--dryrun", "--wait", "--scheduler_args", "cluster=test"]
416+
)
417+
with self.assertRaises(SystemExit):
418+
self.cmd_run.verify_no_extra_args(args)
419+
420+
def test_verify_no_extra_args_stdin_with_default_scheduler(self) -> None:
421+
"""Test that using default scheduler with stdin doesn't conflict."""
422+
# Get the default scheduler and use it explicitly - should not conflict
423+
from torchx.schedulers import get_default_scheduler_name
424+
425+
default_scheduler = get_default_scheduler_name()
426+
427+
args = self.parser.parse_args(["--stdin", "--scheduler", default_scheduler])
428+
# Should not raise any exception since it's the same as default
429+
self.cmd_run.verify_no_extra_args(args)
430+
431+
def test_verify_no_extra_args_stdin_with_default_workspace(self) -> None:
432+
"""Test that using default workspace with stdin doesn't conflict."""
433+
default_workspace = f"file://{Path.cwd()}"
434+
435+
args = self.parser.parse_args(["--stdin", "--workspace", default_workspace])
436+
# Should not raise any exception since it's the same as default
437+
self.cmd_run.verify_no_extra_args(args)
438+
367439

368440
class CmdBuiltinTest(unittest.TestCase):
369441
def test_run(self) -> None:

torchx/runner/api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def build_standalone_workspace(
179179
def run_component(
180180
self,
181181
component: str,
182-
component_args: List[str],
182+
component_args: List[str] | Dict[str, Any],
183183
scheduler: str,
184184
cfg: Optional[Mapping[str, CfgVal]] = None,
185185
workspace: Optional[str] = None,
@@ -238,7 +238,7 @@ def run_component(
238238
def dryrun_component(
239239
self,
240240
component: str,
241-
component_args: List[str],
241+
component_args: List[str] | Dict[str, Any],
242242
scheduler: str,
243243
cfg: Optional[Mapping[str, CfgVal]] = None,
244244
workspace: Optional[str] = None,
@@ -249,10 +249,13 @@ def dryrun_component(
249249
component, but just returns what "would" have run.
250250
"""
251251
component_def = get_component(component)
252+
args_from_cli = component_args if isinstance(component_args, list) else []
253+
args_from_json = component_args if isinstance(component_args, dict) else {}
252254
app = materialize_appdef(
253255
component_def.fn,
254-
component_args,
256+
args_from_cli,
255257
self._component_defaults.get(component, None),
258+
args_from_json,
256259
)
257260
return self.dryrun(
258261
app,

torchx/specs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,8 @@ def gpu_x_1() -> Dict[str, Resource]:
225225
"make_app_handle",
226226
"materialize_appdef",
227227
"parse_mounts",
228+
"torchx_run_args_from_argparse",
229+
"torchx_run_args_from_json",
230+
"TorchXRunArgs",
228231
"ALL",
229232
]

0 commit comments

Comments
 (0)