Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def run(self, args: argparse.Namespace) -> None:
class CmdRun(SubCommand):
def __init__(self) -> None:
self._subparser: Optional[argparse.ArgumentParser] = None
self._stdin_data_json: Optional[Dict[str, Any]] = None

def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
scheduler_names = get_scheduler_factories().keys()
Expand Down Expand Up @@ -369,6 +370,15 @@ def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> No
torchx_run_args.scheduler_cfg = cfg
self._run_inner(runner, torchx_run_args)

def _get_torchx_stdin_args(
self, args: argparse.Namespace
) -> Optional[Dict[str, Any]]:
if not args.stdin:
return None
if self._stdin_data_json is None:
self._stdin_data_json = self.torchx_json_from_stdin()
return self._stdin_data_json

def torchx_json_from_stdin(self) -> Dict[str, Any]:
try:
stdin_data_json = json.load(sys.stdin)
Expand Down Expand Up @@ -419,11 +429,11 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None:
)

def _run(self, runner: Runner, args: argparse.Namespace) -> None:
# Verify no conflicting arguments when using to loop over the stdin
self.verify_no_extra_args(args)
if args.stdin:
stdin_data_json = self.torchx_json_from_stdin()
self._run_from_stdin_args(runner, stdin_data_json)
stdin_data_json = self._get_torchx_stdin_args(args)
if stdin_data_json is not None:
self._run_from_stdin_args(runner, stdin_data_json)
else:
self._run_from_cli_args(runner, args)

Expand Down
Loading