77# pyre-strict
88
99import argparse
10+ import json
1011import logging
1112import os
1213import sys
4142 "missing component name, either provide it from the CLI or in .torchxconfig"
4243)
4344
45+ LOCAL_SCHEDULER_WARNING_MSG = (
46+ "`local` scheduler is deprecated and will be"
47+ " removed in the near future,"
48+ " please use other variants of the local scheduler"
49+ " (e.g. `local_cwd`)"
50+ )
4451
4552logger : logging .Logger = logging .getLogger (__name__ )
4653
@@ -256,35 +263,31 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
256263 default = False ,
257264 help = "Add additional prefix to log lines to indicate which replica is printing the log" ,
258265 )
266+ subparser .add_argument (
267+ "--stdin" ,
268+ action = "store_true" ,
269+ default = False ,
270+ help = "Read JSON input from stdin to parse into torchx run args and run the component." ,
271+ )
259272 subparser .add_argument (
260273 "component_name_and_args" ,
261274 nargs = argparse .REMAINDER ,
262275 )
263276
264- def _run (self , runner : Runner , args : argparse . Namespace ) -> None :
277+ def _run_inner (self , runner : Runner , args : TorchXRunArgs ) -> None :
265278 if args .scheduler == "local" :
266- logger .warning (
267- "`local` scheduler is deprecated and will be"
268- " removed in the near future,"
269- " please use other variants of the local scheduler"
270- " (e.g. `local_cwd`)"
271- )
279+ logger .warning (LOCAL_SCHEDULER_WARNING_MSG )
272280
273- cfg = dict (runner .cfg_from_str (args .scheduler , args .scheduler_args ))
274- config .apply (scheduler = args .scheduler , cfg = cfg )
281+ config .apply (scheduler = args .scheduler , cfg = args .scheduler_cfg )
275282
276- component , component_args = _parse_component_name_and_args (
277- args .component_name_and_args ,
278- none_throws (self ._subparser ),
279- )
280283 try :
281284 if args .dryrun :
282285 dryrun_info = runner .dryrun_component (
283- component ,
284- component_args ,
286+ args . component_name ,
287+ args . component_args ,
285288 args .scheduler ,
286289 workspace = args .workspace ,
287- cfg = cfg ,
290+ cfg = args . scheduler_cfg ,
288291 parent_run_id = args .parent_run_id ,
289292 )
290293 print (
@@ -295,11 +298,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
295298 print ("\n === SCHEDULER REQUEST ===\n " f"{ dryrun_info } " )
296299 else :
297300 app_handle = runner .run_component (
298- component ,
299- component_args ,
301+ args . component_name ,
302+ args . component_args_str ,
300303 args .scheduler ,
301304 workspace = args .workspace ,
302- cfg = cfg ,
305+ cfg = args . scheduler_cfg ,
303306 parent_run_id = args .parent_run_id ,
304307 )
305308 # DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -320,7 +323,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
320323 )
321324
322325 except (ComponentValidationException , ComponentNotFoundException ) as e :
323- error_msg = f"\n Failed to run component `{ component } ` got errors: \n { e } "
326+ error_msg = (
327+ f"\n Failed to run component `{ args .component_name } ` got errors: \n { e } "
328+ )
324329 logger .error (error_msg )
325330 sys .exit (1 )
326331 except specs .InvalidRunConfigException as e :
@@ -335,6 +340,87 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
335340 print (error_msg % (e , args .scheduler , args .scheduler ), file = sys .stderr )
336341 sys .exit (1 )
337342
343+ def _run_from_cli_args (self , runner : Runner , args : argparse .Namespace ) -> None :
344+ scheduler_opts = runner .scheduler_run_opts (args .scheduler )
345+ cfg = scheduler_opts .cfg_from_str (args .scheduler_args )
346+
347+ component , component_args = _parse_component_name_and_args (
348+ args .component_name_and_args ,
349+ none_throws (self ._subparser ),
350+ )
351+ torchx_run_args = torchx_run_args_from_argparse (
352+ args , component , component_args , cfg
353+ )
354+ self ._run_inner (runner , torchx_run_args )
355+
356+ def _run_from_stdin_args (self , runner : Runner , stdin_data : Dict [str , Any ]) -> None :
357+ torchx_run_args = torchx_run_args_from_json (stdin_data )
358+ scheduler_opts = runner .scheduler_run_opts (torchx_run_args .scheduler )
359+ cfg = scheduler_opts .cfg_from_json_repr (
360+ json .dumps (torchx_run_args .scheduler_args )
361+ )
362+ torchx_run_args .scheduler_cfg = cfg
363+ self ._run_inner (runner , torchx_run_args )
364+
365+ def torchx_json_from_stdin (self ) -> Dict [str , Any ]:
366+ try :
367+ stdin_data_json = json .load (sys .stdin )
368+ if not isinstance (stdin_data_json , dict ):
369+ logger .error (
370+ "Invalid JSON input for `torchx run` command. Expected a dictionary."
371+ )
372+ sys .exit (1 )
373+ return stdin_data_json
374+ except (json .JSONDecodeError , EOFError ):
375+ logger .error (
376+ "Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
377+ )
378+ sys .exit (1 )
379+
380+ def verify_no_extra_args (self , args : argparse .Namespace ) -> None :
381+ """
382+ Verifies that only --stdin was provided when using stdin mode.
383+ """
384+ if not args .stdin :
385+ return
386+
387+ subparser = none_throws (self ._subparser )
388+ conflicting_args = []
389+
390+ # Check each argument against its default value
391+ for action in subparser ._actions :
392+ if action .dest == "stdin" : # Skip stdin itself
393+ continue
394+ if action .dest == "help" : # Skip help
395+ continue
396+
397+ current_value = getattr (args , action .dest , None )
398+ default_value = action .default
399+
400+ # For arguments that differ from default
401+ if current_value != default_value :
402+ # Handle special cases where non-default doesn't mean explicitly set
403+ if action .dest == "component_name_and_args" and current_value == []:
404+ continue # Empty list is still default
405+
406+ conflicting_args .append (f"--{ action .dest .replace ('_' , '-' )} " )
407+
408+ if conflicting_args :
409+ subparser .error (
410+ f"Cannot specify { ', ' .join (conflicting_args )} when using --stdin. "
411+ "All configuration should be provided in JSON input."
412+ )
413+
414+ def _run (self , runner : Runner , args : argparse .Namespace ) -> None :
415+ # Verify no conflicting arguments when using to loop over the stdin
416+ self .verify_no_extra_args (args )
417+
418+ if args .stdin :
419+ stdin_data_json = self .torchx_json_from_stdin ()
420+ self ._run_from_stdin_args (runner , stdin_data_json )
421+ else :
422+ self ._run_from_cli_args (runner , args )
423+
338424 def run (self , args : argparse .Namespace ) -> None :
339425 os .environ ["TORCHX_CONTEXT_NAME" ] = os .getenv ("TORCHX_CONTEXT_NAME" , "cli_run" )
340426 component_defaults = load_sections (prefix = "component" )
0 commit comments