77# pyre-strict
88
99import argparse
10+ import json
1011import logging
1112import os
1213import sys
4142 "missing component name, either provide it from the CLI or in .torchxconfig"
4243)
4344
# Deprecation notice emitted through `logger.warning` when the scheduler
# argument is exactly "local".
LOCAL_SCHEDULER_WARNING_MSG = (
    "`local` scheduler is deprecated and will be removed in the near future, "
    "please use other variants of the local scheduler (e.g. `local_cwd`)"
)

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
4653
@@ -54,7 +61,7 @@ class TorchXRunArgs:
5461 dryrun : bool = False
5562 wait : bool = False
5663 log : bool = False
57- workspace : str = f"file:// { Path . cwd () } "
64+ workspace : str = " "
5865 parent_run_id : Optional [str ] = None
5966 tee_logs : bool = False
6067 component_args : Dict [str , Any ] = field (default_factory = dict )
@@ -83,7 +90,10 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
8390 "Please check your JSON and try launching again." ,
8491 )
8592
86- return TorchXRunArgs (** filtered_json_data )
93+ torchx_args = TorchXRunArgs (** filtered_json_data )
94+ if torchx_args .workspace == "" :
95+ torchx_args .workspace = f"file://{ Path .cwd ()} "
96+ return torchx_args
8797
8898
8999def torchx_run_args_from_argparse (
@@ -256,35 +266,35 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
256266 default = False ,
257267 help = "Add additional prefix to log lines to indicate which replica is printing the log" ,
258268 )
269+ subparser .add_argument (
270+ "--stdin" ,
271+ action = "store_true" ,
272+ default = False ,
273+ help = "Read JSON input from stdin to parse into torchx run args and run the component." ,
274+ )
259275 subparser .add_argument (
260276 "component_name_and_args" ,
261277 nargs = argparse .REMAINDER ,
262278 )
263279
264- def _run (self , runner : Runner , args : argparse . Namespace ) -> None :
280+ def _run_inner (self , runner : Runner , args : TorchXRunArgs ) -> None :
265281 if args .scheduler == "local" :
266- logger .warning (
267- "`local` scheduler is deprecated and will be"
268- " removed in the near future,"
269- " please use other variants of the local scheduler"
270- " (e.g. `local_cwd`)"
271- )
272-
273- cfg = dict (runner .cfg_from_str (args .scheduler , args .scheduler_args ))
274- config .apply (scheduler = args .scheduler , cfg = cfg )
282+ logger .warning (LOCAL_SCHEDULER_WARNING_MSG )
275283
276- component , component_args = _parse_component_name_and_args (
277- args .component_name_and_args ,
278- none_throws (self ._subparser ),
284+ config .apply (scheduler = args .scheduler , cfg = args .scheduler_cfg )
285+ component_args = (
286+ args .component_args_str
287+ if args .component_args_str != []
288+ else args .component_args
279289 )
280290 try :
281291 if args .dryrun :
282292 dryrun_info = runner .dryrun_component (
283- component ,
293+ args . component_name ,
284294 component_args ,
285295 args .scheduler ,
286296 workspace = args .workspace ,
287- cfg = cfg ,
297+ cfg = args . scheduler_cfg ,
288298 parent_run_id = args .parent_run_id ,
289299 )
290300 print (
@@ -295,11 +305,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
295305 print ("\n === SCHEDULER REQUEST ===\n " f"{ dryrun_info } " )
296306 else :
297307 app_handle = runner .run_component (
298- component ,
308+ args . component_name ,
299309 component_args ,
300310 args .scheduler ,
301311 workspace = args .workspace ,
302- cfg = cfg ,
312+ cfg = args . scheduler_cfg ,
303313 parent_run_id = args .parent_run_id ,
304314 )
305315 # DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -320,7 +330,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
320330 )
321331
322332 except (ComponentValidationException , ComponentNotFoundException ) as e :
323- error_msg = f"\n Failed to run component `{ component } ` got errors: \n { e } "
333+ error_msg = (
334+ f"\n Failed to run component `{ args .component_name } ` got errors: \n { e } "
335+ )
324336 logger .error (error_msg )
325337 sys .exit (1 )
326338 except specs .InvalidRunConfigException as e :
@@ -335,6 +347,86 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
335347 print (error_msg % (e , args .scheduler , args .scheduler ), file = sys .stderr )
336348 sys .exit (1 )
337349
def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
    """
    Run a component using the arguments parsed from the command line.

    The scheduler config is parsed from the ``scheduler_args`` string, the
    component name and its arguments are split out of
    ``component_name_and_args``, and the combined ``TorchXRunArgs`` are
    handed to ``_run_inner``.
    """
    # Resolve the scheduler config first, then the component spec.
    scheduler_cfg = runner.scheduler_run_opts(args.scheduler).cfg_from_str(
        args.scheduler_args
    )
    name_and_args = _parse_component_name_and_args(
        args.component_name_and_args,
        none_throws(self._subparser),
    )
    component_name, component_arguments = name_and_args

    run_args = torchx_run_args_from_argparse(
        args, component_name, component_arguments, scheduler_cfg
    )
    self._run_inner(runner, run_args)
362+
def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
    """
    Run a component using a JSON dictionary (read from stdin) as run args.

    ``stdin_data`` is converted to ``TorchXRunArgs``; its ``scheduler_args``
    are re-serialized to JSON and parsed by the scheduler's run options into
    ``scheduler_cfg`` before delegating to ``_run_inner``.
    """
    run_args = torchx_run_args_from_json(stdin_data)
    run_opts = runner.scheduler_run_opts(run_args.scheduler)
    # The run options parse a JSON string, so dump the scheduler args back out.
    run_args.scheduler_cfg = run_opts.cfg_from_json_repr(
        json.dumps(run_args.scheduler_args)
    )
    self._run_inner(runner, run_args)
371+
def torchx_json_from_stdin(self) -> Dict[str, Any]:
    """
    Read a JSON document from ``sys.stdin`` and return it as a dictionary.

    Logs an error and exits with status 1 when the input cannot be parsed
    as JSON or when the parsed top-level value is not a dictionary.
    """
    try:
        payload = json.load(sys.stdin)
    except (json.JSONDecodeError, EOFError):
        logger.error(
            "Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
        )
        sys.exit(1)

    if isinstance(payload, dict):
        return payload

    # Valid JSON, but not an object (e.g. a list or a scalar).
    logger.error(
        "Invalid JSON input for `torchx run` command. Expected a dictionary."
    )
    sys.exit(1)
386+
def verify_no_extra_args(self, args: argparse.Namespace) -> None:
    """
    Verifies that only --stdin was provided when using stdin mode.

    Does nothing when ``args.stdin`` is falsy. Otherwise every action
    registered on the subparser (except ``stdin`` and ``help``) is compared
    against its declared default; any argument whose parsed value differs is
    reported through ``subparser.error``.

    Args:
        args: the parsed namespace for the ``run`` subcommand.
    """
    if not args.stdin:
        return

    subparser = none_throws(self._subparser)
    conflicting_args = []

    # NOTE: `_actions` is a private argparse attribute; it is used here
    # because it is how this method enumerates the registered arguments.
    for action in subparser._actions:
        # `--stdin` itself and the help action are never conflicts.
        if action.dest in ("stdin", "help"):
            continue

        current_value = getattr(args, action.dest, None)
        if current_value == action.default:
            continue

        # Special case where non-default doesn't mean explicitly set:
        # an empty remainder list is still treated as "not provided".
        if action.dest == "component_name_and_args" and current_value == []:
            continue

        conflicting_args.append(f"--{action.dest.replace('_', '-')}")

    if conflicting_args:
        subparser.error(
            f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
            "All configuration should be provided in JSON input."
        )
420+
def _run(self, runner: Runner, args: argparse.Namespace) -> None:
    """
    Dispatch the run either from JSON read on stdin or from CLI arguments.

    ``verify_no_extra_args`` is always called first, so that other
    arguments combined with ``--stdin`` are rejected before anything runs.
    """
    self.verify_no_extra_args(args)

    if not args.stdin:
        self._run_from_cli_args(runner, args)
        return

    self._run_from_stdin_args(runner, self.torchx_json_from_stdin())
429+
338430 def run (self , args : argparse .Namespace ) -> None :
339431 os .environ ["TORCHX_CONTEXT_NAME" ] = os .getenv ("TORCHX_CONTEXT_NAME" , "cli_run" )
340432 component_defaults = load_sections (prefix = "component" )
0 commit comments