diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index dd774203c..126df3527 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -7,6 +7,7 @@ # pyre-strict import argparse +import json import logging import os import sys @@ -41,6 +42,12 @@ "missing component name, either provide it from the CLI or in .torchxconfig" ) +LOCAL_SCHEDULER_WARNING_MSG = ( + "`local` scheduler is deprecated and will be" + " removed in the near future," + " please use other variants of the local scheduler" + " (e.g. `local_cwd`)" +) logger: logging.Logger = logging.getLogger(__name__) @@ -54,7 +61,7 @@ class TorchXRunArgs: dryrun: bool = False wait: bool = False log: bool = False - workspace: str = f"file://{Path.cwd()}" + workspace: str = "" parent_run_id: Optional[str] = None tee_logs: bool = False component_args: Dict[str, Any] = field(default_factory=dict) @@ -83,7 +90,10 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs: "Please check your JSON and try launching again.", ) - return TorchXRunArgs(**filtered_json_data) + torchx_args = TorchXRunArgs(**filtered_json_data) + if torchx_args.workspace == "": + torchx_args.workspace = f"file://{Path.cwd()}" + return torchx_args def torchx_run_args_from_argparse( @@ -256,35 +266,35 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None: default=False, help="Add additional prefix to log lines to indicate which replica is printing the log", ) + subparser.add_argument( + "--stdin", + action="store_true", + default=False, + help="Read JSON input from stdin to parse into torchx run args and run the component.", + ) subparser.add_argument( "component_name_and_args", nargs=argparse.REMAINDER, ) - def _run(self, runner: Runner, args: argparse.Namespace) -> None: + def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None: if args.scheduler == "local": - logger.warning( - "`local` scheduler is deprecated and will be" - " removed in the near future," - " please use other variants of the local scheduler" - " (e.g. `local_cwd`)" - ) - - cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args)) - config.apply(scheduler=args.scheduler, cfg=cfg) + logger.warning(LOCAL_SCHEDULER_WARNING_MSG) - component, component_args = _parse_component_name_and_args( - args.component_name_and_args, - none_throws(self._subparser), + config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg) + component_args = ( + args.component_args_str + if args.component_args_str != [] + else args.component_args ) try: if args.dryrun: dryrun_info = runner.dryrun_component( - component, + args.component_name, component_args, args.scheduler, workspace=args.workspace, - cfg=cfg, + cfg=args.scheduler_cfg, parent_run_id=args.parent_run_id, ) print( @@ -295,11 +305,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}") else: app_handle = runner.run_component( - component, + args.component_name, component_args, args.scheduler, workspace=args.workspace, - cfg=cfg, + cfg=args.scheduler_cfg, parent_run_id=args.parent_run_id, ) # DO NOT delete this line. It is used by slurm tests to retrieve the app id @@ -320,7 +330,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: ) except (ComponentValidationException, ComponentNotFoundException) as e: - error_msg = f"\nFailed to run component `{component}` got errors: \n {e}" + error_msg = ( + f"\nFailed to run component `{args.component_name}` got errors: \n {e}" + ) logger.error(error_msg) sys.exit(1) except specs.InvalidRunConfigException as e: @@ -335,6 +347,86 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr) sys.exit(1) + def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None: + scheduler_opts = runner.scheduler_run_opts(args.scheduler) + cfg = scheduler_opts.cfg_from_str(args.scheduler_args) + + component, component_args = _parse_component_name_and_args( + args.component_name_and_args, + none_throws(self._subparser), + ) + torchx_run_args = torchx_run_args_from_argparse( + args, component, component_args, cfg + ) + self._run_inner(runner, torchx_run_args) + + def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None: + torchx_run_args = torchx_run_args_from_json(stdin_data) + scheduler_opts = runner.scheduler_run_opts(torchx_run_args.scheduler) + cfg = scheduler_opts.cfg_from_json_repr( + json.dumps(torchx_run_args.scheduler_args) + ) + torchx_run_args.scheduler_cfg = cfg + self._run_inner(runner, torchx_run_args) + + def torchx_json_from_stdin(self) -> Dict[str, Any]: + try: + stdin_data_json = json.load(sys.stdin) + if not isinstance(stdin_data_json, dict): + logger.error( + "Invalid JSON input for `torchx run` command. Expected a dictionary." + ) + sys.exit(1) + return stdin_data_json + except (json.JSONDecodeError, EOFError): + logger.error( + "Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input." + ) + sys.exit(1) + + def verify_no_extra_args(self, args: argparse.Namespace) -> None: + """ + Verifies that only --stdin was provided when using stdin mode. + """ + if not args.stdin: + return + + subparser = none_throws(self._subparser) + conflicting_args = [] + + # Check each argument against its default value + for action in subparser._actions: + if action.dest == "stdin": # Skip stdin itself + continue + if action.dest == "help": # Skip help + continue + + current_value = getattr(args, action.dest, None) + default_value = action.default + + # For arguments that differ from default + if current_value != default_value: + # Handle special cases where non-default doesn't mean explicitly set + if action.dest == "component_name_and_args" and current_value == []: + continue # Empty list is still default + print(f"*********\n {default_value} = {current_value}") + conflicting_args.append(f"--{action.dest.replace('_', '-')}") + + if conflicting_args: + subparser.error( + f"Cannot specify {', '.join(conflicting_args)} when using --stdin. " + "All configuration should be provided in JSON input." + ) + + def _run(self, runner: Runner, args: argparse.Namespace) -> None: + # Verify no conflicting arguments when using to loop over the stdin + self.verify_no_extra_args(args) + if args.stdin: + stdin_data_json = self.torchx_json_from_stdin() + self._run_from_stdin_args(runner, stdin_data_json) + else: + self._run_from_cli_args(runner, args) + def run(self, args: argparse.Namespace) -> None: os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run") component_defaults = load_sections(prefix="component") diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index ed64f8fee..4cdadaa58 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -10,7 +10,6 @@ import argparse import dataclasses import io - import os import shutil import signal @@ -67,14 +66,12 @@ def tearDown(self) -> None: torchxconfig.called_args = set() def test_run_with_multiple_scheduler_args(self) -> None: - args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"] with self.assertRaises(SystemExit) as cm: self.parser.parse_args(args) self.assertEqual(cm.exception.code, 1) def test_run_with_multiple_schedule_args(self) -> None: - args = [ "--scheduler", "local_cwd", @@ -179,13 +176,13 @@ def test_conf_file_missing(self) -> None: with patch( "torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir] ): + args = self.parser.parse_args( + [ + "--scheduler", + "local_cwd", + ] + ) with self.assertRaises(SystemExit): - args = self.parser.parse_args( - [ - "--scheduler", - "local_cwd", - ] - ) self.cmd_run.run(args) @patch("torchx.runner.Runner.run") @@ -364,6 +361,96 @@ def test_parse_component_name_and_args_with_default(self) -> None: _parse_component_name_and_args(["-m", "hello"], sp, dirs), ) + def test_verify_no_extra_args_stdin_only(self) -> None: + """Test that only --stdin is allowed when using stdin mode.""" + args = self.parser.parse_args(["--stdin"]) + # Should not raise any exception + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_no_stdin(self) -> None: + """Test that verification is skipped when not using stdin.""" + args = self.parser.parse_args(["--scheduler", "local_cwd", "utils.echo"]) + # Should not raise any exception + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_component_name(self) -> None: + """Test that component name conflicts with stdin.""" + args = self.parser.parse_args(["--stdin", "utils.echo"]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_scheduler_args(self) -> None: + """Test that scheduler_args conflicts with stdin.""" + args = self.parser.parse_args(["--stdin", "--scheduler_args", "cluster=test"]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_scheduler(self) -> None: + """Test that non-default scheduler conflicts with stdin.""" + args = self.parser.parse_args(["--stdin", "--scheduler", "kubernetes"]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None: + """Test that boolean flags conflict with stdin.""" + boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"] + for flag in boolean_flags: + args = self.parser.parse_args(["--stdin", flag]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_value_args(self) -> None: + """Test that arguments with values conflict with stdin.""" + args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + args = self.parser.parse_args(["--stdin", "--parent_run_id", "experiment_123"]) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_multiple_conflicts(self) -> None: + """Test that multiple conflicting arguments with stdin are detected.""" + args = self.parser.parse_args( + ["--stdin", "--dryrun", "--wait", "--scheduler_args", "cluster=test"] + ) + with self.assertRaises(SystemExit): + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_default_scheduler(self) -> None: + """Test that using default scheduler with stdin doesn't conflict.""" + # Get the default scheduler and use it explicitly - should not conflict + from torchx.schedulers import get_default_scheduler_name + + default_scheduler = get_default_scheduler_name() + + args = self.parser.parse_args(["--stdin", "--scheduler", default_scheduler]) + # Should not raise any exception since it's the same as default + self.cmd_run.verify_no_extra_args(args) + + def test_verify_no_extra_args_stdin_with_default_workspace(self) -> None: + """Test that using default workspace with stdin doesn't conflict.""" + # Get the actual default workspace from a fresh parser + fresh_parser = argparse.ArgumentParser() + fresh_cmd_run = CmdRun() + fresh_cmd_run.add_arguments(fresh_parser) + + # Find the workspace argument's default value + workspace_default = None + for action in fresh_parser._actions: + if action.dest == "workspace": + workspace_default = action.default + break + + self.assertIsNotNone( + workspace_default, "workspace argument should have a default" + ) + + # Use the actual default - this should not conflict with stdin + args = fresh_parser.parse_args(["--stdin", "--workspace", workspace_default]) + # Should not raise any exception since it's the same as default + fresh_cmd_run.verify_no_extra_args(args) + class CmdBuiltinTest(unittest.TestCase): def test_run(self) -> None: diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 825bb9747..f38367ecd 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -25,6 +25,7 @@ Type, TYPE_CHECKING, TypeVar, + Union, ) from torchx.runner.events import log_event @@ -167,7 +168,7 @@ def close(self) -> None: def run_component( self, component: str, - component_args: List[str], + component_args: Union[list[str], dict[str, Any]], scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, workspace: Optional[str] = None, @@ -226,7 +227,7 @@ def run_component( def dryrun_component( self, component: str, - component_args: List[str], + component_args: Union[list[str], dict[str, Any]], scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, workspace: Optional[str] = None, @@ -237,10 +238,13 @@ def dryrun_component( component, but just returns what "would" have run. """ component_def = get_component(component) + args_from_cli = component_args if isinstance(component_args, list) else [] + args_from_json = component_args if isinstance(component_args, dict) else {} app = materialize_appdef( component_def.fn, - component_args, + args_from_cli, self._component_defaults.get(component, None), + args_from_json, ) return self.dryrun( app, diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index 40b0d1202..ae1622e61 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -225,5 +225,8 @@ def gpu_x_1() -> Dict[str, Resource]: "make_app_handle", "materialize_appdef", "parse_mounts", + "torchx_run_args_from_argparse", + "torchx_run_args_from_json", + "TorchXRunArgs", "ALL", ] diff --git a/torchx/specs/builders.py b/torchx/specs/builders.py index 4f7c3af25..126518854 100644 --- a/torchx/specs/builders.py +++ b/torchx/specs/builders.py @@ -213,7 +213,11 @@ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef: arg_value = getattr(parsed_args, param_name) parameter_type = parameter.annotation parameter_type = decode_optional(parameter_type) - arg_value = decode(arg_value, parameter_type) + if ( + parameter_type != arg_value.__class__ + and parameter.kind != inspect.Parameter.VAR_POSITIONAL + ): + arg_value = decode(arg_value, parameter_type) if parameter.kind == inspect.Parameter.VAR_POSITIONAL: var_args = arg_value elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: