Skip to content

Commit 448e08e

Browse files
authored
Support JSON input to torchx run
Differential Revision: D80731827 Pull Request resolved: #1107
1 parent 33fca61 commit 448e08e

File tree

5 files changed

+223
-33
lines changed

5 files changed

+223
-33
lines changed

torchx/cli/cmd_run.py

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import argparse
10+
import json
1011
import logging
1112
import os
1213
import sys
@@ -41,6 +42,12 @@
4142
"missing component name, either provide it from the CLI or in .torchxconfig"
4243
)
4344

45+
LOCAL_SCHEDULER_WARNING_MSG = (
46+
"`local` scheduler is deprecated and will be"
47+
" removed in the near future,"
48+
" please use other variants of the local scheduler"
49+
" (e.g. `local_cwd`)"
50+
)
4451

4552
logger: logging.Logger = logging.getLogger(__name__)
4653

@@ -54,7 +61,7 @@ class TorchXRunArgs:
5461
dryrun: bool = False
5562
wait: bool = False
5663
log: bool = False
57-
workspace: str = f"file://{Path.cwd()}"
64+
workspace: str = ""
5865
parent_run_id: Optional[str] = None
5966
tee_logs: bool = False
6067
component_args: Dict[str, Any] = field(default_factory=dict)
@@ -83,7 +90,10 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
8390
"Please check your JSON and try launching again.",
8491
)
8592

86-
return TorchXRunArgs(**filtered_json_data)
93+
torchx_args = TorchXRunArgs(**filtered_json_data)
94+
if torchx_args.workspace == "":
95+
torchx_args.workspace = f"file://{Path.cwd()}"
96+
return torchx_args
8797

8898

8999
def torchx_run_args_from_argparse(
@@ -256,35 +266,35 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
256266
default=False,
257267
help="Add additional prefix to log lines to indicate which replica is printing the log",
258268
)
269+
subparser.add_argument(
270+
"--stdin",
271+
action="store_true",
272+
default=False,
273+
help="Read JSON input from stdin to parse into torchx run args and run the component.",
274+
)
259275
subparser.add_argument(
260276
"component_name_and_args",
261277
nargs=argparse.REMAINDER,
262278
)
263279

264-
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
280+
def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
265281
if args.scheduler == "local":
266-
logger.warning(
267-
"`local` scheduler is deprecated and will be"
268-
" removed in the near future,"
269-
" please use other variants of the local scheduler"
270-
" (e.g. `local_cwd`)"
271-
)
272-
273-
cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args))
274-
config.apply(scheduler=args.scheduler, cfg=cfg)
282+
logger.warning(LOCAL_SCHEDULER_WARNING_MSG)
275283

276-
component, component_args = _parse_component_name_and_args(
277-
args.component_name_and_args,
278-
none_throws(self._subparser),
284+
config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg)
285+
component_args = (
286+
args.component_args_str
287+
if args.component_args_str != []
288+
else args.component_args
279289
)
280290
try:
281291
if args.dryrun:
282292
dryrun_info = runner.dryrun_component(
283-
component,
293+
args.component_name,
284294
component_args,
285295
args.scheduler,
286296
workspace=args.workspace,
287-
cfg=cfg,
297+
cfg=args.scheduler_cfg,
288298
parent_run_id=args.parent_run_id,
289299
)
290300
print(
@@ -295,11 +305,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
295305
print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}")
296306
else:
297307
app_handle = runner.run_component(
298-
component,
308+
args.component_name,
299309
component_args,
300310
args.scheduler,
301311
workspace=args.workspace,
302-
cfg=cfg,
312+
cfg=args.scheduler_cfg,
303313
parent_run_id=args.parent_run_id,
304314
)
305315
# DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -320,7 +330,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
320330
)
321331

322332
except (ComponentValidationException, ComponentNotFoundException) as e:
323-
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
333+
error_msg = (
334+
f"\nFailed to run component `{args.component_name}` got errors: \n {e}"
335+
)
324336
logger.error(error_msg)
325337
sys.exit(1)
326338
except specs.InvalidRunConfigException as e:
@@ -335,6 +347,86 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
335347
print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
336348
sys.exit(1)
337349

350+
def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
351+
scheduler_opts = runner.scheduler_run_opts(args.scheduler)
352+
cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
353+
354+
component, component_args = _parse_component_name_and_args(
355+
args.component_name_and_args,
356+
none_throws(self._subparser),
357+
)
358+
torchx_run_args = torchx_run_args_from_argparse(
359+
args, component, component_args, cfg
360+
)
361+
self._run_inner(runner, torchx_run_args)
362+
363+
def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
364+
torchx_run_args = torchx_run_args_from_json(stdin_data)
365+
scheduler_opts = runner.scheduler_run_opts(torchx_run_args.scheduler)
366+
cfg = scheduler_opts.cfg_from_json_repr(
367+
json.dumps(torchx_run_args.scheduler_args)
368+
)
369+
torchx_run_args.scheduler_cfg = cfg
370+
self._run_inner(runner, torchx_run_args)
371+
372+
def torchx_json_from_stdin(self) -> Dict[str, Any]:
373+
try:
374+
stdin_data_json = json.load(sys.stdin)
375+
if not isinstance(stdin_data_json, dict):
376+
logger.error(
377+
"Invalid JSON input for `torchx run` command. Expected a dictionary."
378+
)
379+
sys.exit(1)
380+
return stdin_data_json
381+
except (json.JSONDecodeError, EOFError):
382+
logger.error(
383+
"Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
384+
)
385+
sys.exit(1)
386+
387+
def verify_no_extra_args(self, args: argparse.Namespace) -> None:
388+
"""
389+
Verifies that only --stdin was provided when using stdin mode.
390+
"""
391+
if not args.stdin:
392+
return
393+
394+
subparser = none_throws(self._subparser)
395+
conflicting_args = []
396+
397+
# Check each argument against its default value
398+
for action in subparser._actions:
399+
if action.dest == "stdin": # Skip stdin itself
400+
continue
401+
if action.dest == "help": # Skip help
402+
continue
403+
404+
current_value = getattr(args, action.dest, None)
405+
default_value = action.default
406+
407+
# For arguments that differ from default
408+
if current_value != default_value:
409+
# Handle special cases where non-default doesn't mean explicitly set
410+
if action.dest == "component_name_and_args" and current_value == []:
411+
continue # Empty list is still default
412+
print(f"*********\n {default_value} = {current_value}")
413+
conflicting_args.append(f"--{action.dest.replace('_', '-')}")
414+
415+
if conflicting_args:
416+
subparser.error(
417+
f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
418+
"All configuration should be provided in JSON input."
419+
)
420+
421+
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
422+
# Verify no conflicting arguments when using to loop over the stdin
423+
self.verify_no_extra_args(args)
424+
if args.stdin:
425+
stdin_data_json = self.torchx_json_from_stdin()
426+
self._run_from_stdin_args(runner, stdin_data_json)
427+
else:
428+
self._run_from_cli_args(runner, args)
429+
338430
def run(self, args: argparse.Namespace) -> None:
339431
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
340432
component_defaults = load_sections(prefix="component")

torchx/cli/test/cmd_run_test.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import argparse
1111
import dataclasses
1212
import io
13-
1413
import os
1514
import shutil
1615
import signal
@@ -67,14 +66,12 @@ def tearDown(self) -> None:
6766
torchxconfig.called_args = set()
6867

6968
def test_run_with_multiple_scheduler_args(self) -> None:
70-
7169
args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"]
7270
with self.assertRaises(SystemExit) as cm:
7371
self.parser.parse_args(args)
7472
self.assertEqual(cm.exception.code, 1)
7573

7674
def test_run_with_multiple_schedule_args(self) -> None:
77-
7875
args = [
7976
"--scheduler",
8077
"local_cwd",
@@ -179,13 +176,13 @@ def test_conf_file_missing(self) -> None:
179176
with patch(
180177
"torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir]
181178
):
179+
args = self.parser.parse_args(
180+
[
181+
"--scheduler",
182+
"local_cwd",
183+
]
184+
)
182185
with self.assertRaises(SystemExit):
183-
args = self.parser.parse_args(
184-
[
185-
"--scheduler",
186-
"local_cwd",
187-
]
188-
)
189186
self.cmd_run.run(args)
190187

191188
@patch("torchx.runner.Runner.run")
@@ -364,6 +361,96 @@ def test_parse_component_name_and_args_with_default(self) -> None:
364361
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
365362
)
366363

364+
def test_verify_no_extra_args_stdin_only(self) -> None:
365+
"""Test that only --stdin is allowed when using stdin mode."""
366+
args = self.parser.parse_args(["--stdin"])
367+
# Should not raise any exception
368+
self.cmd_run.verify_no_extra_args(args)
369+
370+
def test_verify_no_extra_args_no_stdin(self) -> None:
371+
"""Test that verification is skipped when not using stdin."""
372+
args = self.parser.parse_args(["--scheduler", "local_cwd", "utils.echo"])
373+
# Should not raise any exception
374+
self.cmd_run.verify_no_extra_args(args)
375+
376+
def test_verify_no_extra_args_stdin_with_component_name(self) -> None:
377+
"""Test that component name conflicts with stdin."""
378+
args = self.parser.parse_args(["--stdin", "utils.echo"])
379+
with self.assertRaises(SystemExit):
380+
self.cmd_run.verify_no_extra_args(args)
381+
382+
def test_verify_no_extra_args_stdin_with_scheduler_args(self) -> None:
383+
"""Test that scheduler_args conflicts with stdin."""
384+
args = self.parser.parse_args(["--stdin", "--scheduler_args", "cluster=test"])
385+
with self.assertRaises(SystemExit):
386+
self.cmd_run.verify_no_extra_args(args)
387+
388+
def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:
389+
"""Test that non-default scheduler conflicts with stdin."""
390+
args = self.parser.parse_args(["--stdin", "--scheduler", "kubernetes"])
391+
with self.assertRaises(SystemExit):
392+
self.cmd_run.verify_no_extra_args(args)
393+
394+
def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
395+
"""Test that boolean flags conflict with stdin."""
396+
boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"]
397+
for flag in boolean_flags:
398+
args = self.parser.parse_args(["--stdin", flag])
399+
with self.assertRaises(SystemExit):
400+
self.cmd_run.verify_no_extra_args(args)
401+
402+
def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
403+
"""Test that arguments with values conflict with stdin."""
404+
args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"])
405+
with self.assertRaises(SystemExit):
406+
self.cmd_run.verify_no_extra_args(args)
407+
408+
args = self.parser.parse_args(["--stdin", "--parent_run_id", "experiment_123"])
409+
with self.assertRaises(SystemExit):
410+
self.cmd_run.verify_no_extra_args(args)
411+
412+
def test_verify_no_extra_args_stdin_with_multiple_conflicts(self) -> None:
413+
"""Test that multiple conflicting arguments with stdin are detected."""
414+
args = self.parser.parse_args(
415+
["--stdin", "--dryrun", "--wait", "--scheduler_args", "cluster=test"]
416+
)
417+
with self.assertRaises(SystemExit):
418+
self.cmd_run.verify_no_extra_args(args)
419+
420+
def test_verify_no_extra_args_stdin_with_default_scheduler(self) -> None:
421+
"""Test that using default scheduler with stdin doesn't conflict."""
422+
# Get the default scheduler and use it explicitly - should not conflict
423+
from torchx.schedulers import get_default_scheduler_name
424+
425+
default_scheduler = get_default_scheduler_name()
426+
427+
args = self.parser.parse_args(["--stdin", "--scheduler", default_scheduler])
428+
# Should not raise any exception since it's the same as default
429+
self.cmd_run.verify_no_extra_args(args)
430+
431+
def test_verify_no_extra_args_stdin_with_default_workspace(self) -> None:
432+
"""Test that using default workspace with stdin doesn't conflict."""
433+
# Get the actual default workspace from a fresh parser
434+
fresh_parser = argparse.ArgumentParser()
435+
fresh_cmd_run = CmdRun()
436+
fresh_cmd_run.add_arguments(fresh_parser)
437+
438+
# Find the workspace argument's default value
439+
workspace_default = None
440+
for action in fresh_parser._actions:
441+
if action.dest == "workspace":
442+
workspace_default = action.default
443+
break
444+
445+
self.assertIsNotNone(
446+
workspace_default, "workspace argument should have a default"
447+
)
448+
449+
# Use the actual default - this should not conflict with stdin
450+
args = fresh_parser.parse_args(["--stdin", "--workspace", workspace_default])
451+
# Should not raise any exception since it's the same as default
452+
fresh_cmd_run.verify_no_extra_args(args)
453+
367454

368455
class CmdBuiltinTest(unittest.TestCase):
369456
def test_run(self) -> None:

torchx/runner/api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Type,
2626
TYPE_CHECKING,
2727
TypeVar,
28+
Union,
2829
)
2930

3031
from torchx.runner.events import log_event
@@ -167,7 +168,7 @@ def close(self) -> None:
167168
def run_component(
168169
self,
169170
component: str,
170-
component_args: List[str],
171+
component_args: Union[list[str], dict[str, Any]],
171172
scheduler: str,
172173
cfg: Optional[Mapping[str, CfgVal]] = None,
173174
workspace: Optional[str] = None,
@@ -226,7 +227,7 @@ def run_component(
226227
def dryrun_component(
227228
self,
228229
component: str,
229-
component_args: List[str],
230+
component_args: Union[list[str], dict[str, Any]],
230231
scheduler: str,
231232
cfg: Optional[Mapping[str, CfgVal]] = None,
232233
workspace: Optional[str] = None,
@@ -237,10 +238,13 @@ def dryrun_component(
237238
component, but just returns what "would" have run.
238239
"""
239240
component_def = get_component(component)
241+
args_from_cli = component_args if isinstance(component_args, list) else []
242+
args_from_json = component_args if isinstance(component_args, dict) else {}
240243
app = materialize_appdef(
241244
component_def.fn,
242-
component_args,
245+
args_from_cli,
243246
self._component_defaults.get(component, None),
247+
args_from_json,
244248
)
245249
return self.dryrun(
246250
app,

torchx/specs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,5 +225,8 @@ def gpu_x_1() -> Dict[str, Resource]:
225225
"make_app_handle",
226226
"materialize_appdef",
227227
"parse_mounts",
228+
"torchx_run_args_from_argparse",
229+
"torchx_run_args_from_json",
230+
"TorchXRunArgs",
228231
"ALL",
229232
]

0 commit comments

Comments
 (0)