diff --git a/torchx/components/dist.py b/torchx/components/dist.py index 495a61db3..15adb6415 100644 --- a/torchx/components/dist.py +++ b/torchx/components/dist.py @@ -308,6 +308,126 @@ def ddp( ) +def monarch( + *script_args: str, + script: Optional[str] = None, + m: Optional[str] = None, + image: str = "torchx:latest", + name: str = "/", + h: Optional[str] = None, + cpu: int = 2, + gpu: int = 0, + memMB: int = 1024, + j: str = "1x2", + env: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, str]] = None, + max_retries: int = 0, + rdzv_port: int = 29500, + mounts: Optional[List[str]] = None, + debug: bool = False, + tee: int = 3, +) -> specs.AppDef: + """ + Single-Program-Multiple-Data (SPMD) style application using Monarch actors. + + Launches multiple copies of the same training script using Monarch's actor system + for process management instead of torch.distributed.run. Monarch actors handle + distributed environment setup (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT) + and execute the training script. + + This component can be used for any distributed PyTorch training pattern including + DDP, FSDP, or RPC-based training. + + Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where + ``h`` takes precedence if specified for setting resource requirements. + + Args: + script_args: arguments to the main module + script: script or binary to run within the image + m: the python module path to run + image: image (e.g. docker) + name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``. + Uses the script or module name if ``{runname}`` not specified. + cpu: number of cpus per replica + gpu: number of gpus per replica + memMB: cpu memory in MB per replica + h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) + j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus + env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) + metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3) + max_retries: the number of scheduler retries allowed + rdzv_port: the port used for master_port in distributed setup (default 29500) + mounts: mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). + See scheduler documentation for more info. + debug: whether to run with preset debug flags enabled + tee: accepted for API compatibility with ddp(), but not used by monarch + """ + + if (script is None) == (m is None): + raise ValueError("exactly one of --script and -m must be specified") + + min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j) + + env = env or {} + metadata = metadata or {} + + argname = StructuredNameArgument.parse_from(name=name, m=m, script=script) + env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name + env["TORCHX_TRACKING_RUN_NAME"] = argname.run_name + env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING")) + if debug: + env.update(_TORCH_DEBUG_FLAGS) + + # Determine master_addr based on single vs multi-node + if max_nnodes == 1: + master_addr = "localhost" + else: + master_addr = _noquote(f"$${{{macros.rank0_env}:=localhost}}") + + cmd = [ + "python", + "-m", + "monarch.actor.torchrun", + "--nproc_per_node", + str(nproc_per_node), + "--nnodes", + str(max_nnodes), + "--node_idx", + f"{macros.replica_id}", + "--master_addr", + master_addr, + "--master_port", + str(rdzv_port), + ] + + if script is not None: + cmd += [script] + elif m is not None: + cmd += ["-m", m] + + cmd += script_args + + return specs.AppDef( + name=argname.run_name, + roles=[ + specs.Role( + name=get_role_name(script, m), + image=image, + min_replicas=min_nnodes, + entrypoint="bash", + num_replicas=int(max_nnodes), + resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h), + args=["-c", _args_join(cmd)], + env=env, + port_map={"c10d": rdzv_port}, + max_retries=max_retries, + mounts=specs.parse_mounts(mounts) if mounts else [], + ) + ], + metadata=metadata, + ) + + def get_role_name(script: Optional[str], m: Optional[str]) -> str: if script: # script name/module no extension