diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 126df3527..23f96c8f6 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -206,6 +206,7 @@ def run(self, args: argparse.Namespace) -> None: class CmdRun(SubCommand): def __init__(self) -> None: self._subparser: Optional[argparse.ArgumentParser] = None + self._stdin_data_json: Optional[Dict[str, Any]] = None def add_arguments(self, subparser: argparse.ArgumentParser) -> None: scheduler_names = get_scheduler_factories().keys() @@ -369,6 +370,15 @@ def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> No torchx_run_args.scheduler_cfg = cfg self._run_inner(runner, torchx_run_args) + def _get_torchx_stdin_args( + self, args: argparse.Namespace + ) -> Optional[Dict[str, Any]]: + if not args.stdin: + return None + if self._stdin_data_json is None: + self._stdin_data_json = self.torchx_json_from_stdin() + return self._stdin_data_json + def torchx_json_from_stdin(self) -> Dict[str, Any]: try: stdin_data_json = json.load(sys.stdin) @@ -419,11 +429,11 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None: ) def _run(self, runner: Runner, args: argparse.Namespace) -> None: - # Verify no conflicting arguments when using to loop over the stdin self.verify_no_extra_args(args) if args.stdin: - stdin_data_json = self.torchx_json_from_stdin() - self._run_from_stdin_args(runner, stdin_data_json) + stdin_data_json = self._get_torchx_stdin_args(args) + if stdin_data_json is not None: + self._run_from_stdin_args(runner, stdin_data_json) else: self._run_from_cli_args(runner, args)