@@ -62,6 +62,7 @@
 
 import torchx
 import torchx.specs as specs
+from torchx.components.structured_arg import StructuredJArgument, StructuredNameArgument
 from torchx.specs import macros
 
 _TORCH_DEBUG_FLAGS: Dict[str, str] = {
@@ -80,12 +81,88 @@
 """
 
 
+def spmd(
+    *args: str,
+    script: Optional[str] = None,
+    m: Optional[str] = None,
+    image: str = torchx.IMAGE,
+    name: str = "/",
+    h: str = "gpu.small",
+    j: str = "1x1",
+    env: Optional[Dict[str, str]] = None,
+    max_retries: int = 0,
+    mounts: Optional[List[str]] = None,
+    debug: bool = False,
+) -> specs.AppDef:
97+ """
98+ Usage (by script): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 --script path/to/my/trainer.py -foo bar
99+
100+ Usage (by module): torchx run spmd -j 2x8 -h aws_p4d.24xlarge --name my_experiment/trial_1 -m path.to.my.trainer -foo bar
101+
102+ Usage (infer GPU count): torchx run spmd -j 2 -h p4d.24xlarge ... (same as -j 2x8)
103+
104+ Creates a torchx.specs.AppDef (Job Definition) for a Single-Process-Multiple-Data (SPMD)
105+ style application. See: https://en.wikipedia.org/wiki/Single_program,_multiple_data.
106+
107+ SPMD launches `n x m` (set via the `-j nxm` option) copies of the same program,
108+ where `n` is the number of nodes (hosts) and `m` is the number of processes on each node.
109+
110+ If you have a distributed PyTorch script (DDP, FSDP, RPC) use this component to launch
111+ the distributed application. You can also use `-j 1x1` to launch a single process application
112+ which would be equivalent to launching with regular `python` except that your application
113+ can safely call `torch.distributed.init_process_group(backend)`.
114+
115+ Note: For multi-node distributed runs, the hosts MUST have a network route to each other
116+ AND port 29500 should be open on all hosts. Please check your security group settings.
117+
118+
119+ Args:
120+ args: the arguments to the main module or script (e.g. my/trainer.py -foo bar)
121+ (for docker based runs) the script path must be relative to the WORKDIR of the image
122+ script:
123+ m: the main module name (e.g. my.module.trainer). When this option is used, the `script_args` are passed
124+ as the arguments to the main module). Invoking my module is useful when the relative/absolute path
125+ of the main script is unknown w.r.t the WORKDIR of the image. Use this option when it makes sense to
126+ invoke the main script via `python -m <MAIN.MODULE>`.
127+ image: the base docker image of the workspace, if workspace is disabled, then the image of the job
128+ name: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``
129+ h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
130+ j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
131+ env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
132+ max_retries: the number of scheduler retries allowed
133+ rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
134+ Only takes effect when running multi-node. When running single node, this parameter
135+ is ignored and a random free port is chosen.
136+ mounts: (for docker based runs only) mounts to mount into the worker environment/container
137+ (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
138+ debug: whether to run with preset debug flags enabled
139+
140+ """
+
+    if env is None:
+        env = {}
+
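+    # spmd is a thin wrapper around ddp: it reuses ddp's rendezvous and launch
+    # mechanics while adding structured handling of the -j and --name arguments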
+    return ddp(
+        *args,
+        script=script,
+        m=m,
+        image=image,
+        name=name,
+        h=h,
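+        # StructuredJArgument fills in nproc_per_node from the GPU count of the
+        # named resource h when -j omits it (e.g. "2" -> "2x8" on a p4d.24xlarge)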
+        j=str(StructuredJArgument.parse_from(h, j)),
+        env=env,
+        max_retries=max_retries,
+        mounts=mounts,
+        debug=debug,
+    )
+
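+# A hypothetical programmatic equivalent of the CLI usage in the docstring above
+# (a sketch; the returned specs.AppDef can then be submitted via a torchx scheduler):
+#
+#   app = spmd("-foo", "bar", script="path/to/my/trainer.py",
+#              name="my_experiment/trial_1", h="aws_p4d.24xlarge", j="2x8")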
+
 def ddp(
     *script_args: str,
     script: Optional[str] = None,
     m: Optional[str] = None,
     image: str = torchx.IMAGE,
-    name: Optional[str] = None,
+    name: str = "/",
     h: Optional[str] = None,
     cpu: int = 2,
     gpu: int = 0,
@@ -114,7 +191,8 @@ def ddp(
         script: script or binary to run within the image
         m: the python module path to run
         image: image (e.g. docker)
-        name: job name override (uses the script name if not specified)
+        name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/``
+            or ``/{runname}`` or ``{runname}``. Uses the script or module name if ``{runname}`` is not specified.
         cpu: number of cpus per replica
         gpu: number of gpus per replica
         memMB: cpu memory in MB per replica
@@ -138,14 +216,6 @@ def ddp(
     # nproc_per_node: number of processes on each node
     min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j)
 
-    if script:
-        # script name/module no extension
-        role_name = Path(script).stem
-    elif m:
-        role_name = m.rpartition(".")[2]
-    else:
-        raise ValueError("failed to compute role_name")
-
     rdzv_backend = "c10d"
     if max_nnodes == 1:
         # using port 0 makes elastic choose a free random port, which is ok
@@ -165,8 +235,16 @@ def ddp(
 
     if env is None:
         env = {}
-    env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
 
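+    # split --name into its experiment and run components; either side may be
+    # omitted (e.g. "exp/", "/run", "run"), with the run name defaulting to the
+    # script or module name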
+    argname = StructuredNameArgument.parse_from(
+        name=name,
+        m=m,
+        script=script,
+    )
+
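+    # expose the experiment name to downstream tracking via the job's environment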
+    env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
+
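+    # default the workers' LOGLEVEL from the launching process's environment (falls back to WARNING)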
+    env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
     if debug:
         env.update(_TORCH_DEBUG_FLAGS)
 
@@ -193,10 +271,10 @@ def ddp(
         cmd += ["-m", m]
     cmd += script_args
     return specs.AppDef(
-        name=name or role_name,
+        name=argname.run_name,
         roles=[
             specs.Role(
-                name=role_name,
+                name=get_role_name(script, m),
                 image=image,
                 min_replicas=min_nnodes,
                 entrypoint="bash",
@@ -214,6 +292,17 @@ def ddp(
     )
 
 
+def get_role_name(script: Optional[str], m: Optional[str]) -> str:
+    if script:
+        # script name/module no extension
+        role_name = Path(script).stem
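+        # e.g. "path/to/trainer.py" -> "trainer"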
+    elif m:
+        role_name = m.rpartition(".")[2]
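+        # e.g. "my.module.trainer" -> "trainer" (last component of the module path)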
+    else:
+        raise ValueError("failed to compute role_name")
+    return role_name
+
+
 def _args_join(args: Iterable[str]) -> str:
     """
     _args_join is like shlex.join but if the argument is wrapped in _noquote