torchx/components/dist.py (+120, -0)
@@ -308,6 +308,126 @@ def ddp(
    )


def monarch(
    *script_args: str,
    script: Optional[str] = None,
    m: Optional[str] = None,
    image: str = "torchx:latest",
    name: str = "/",
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
    memMB: int = 1024,
    j: str = "1x2",
    env: Optional[Dict[str, str]] = None,
    metadata: Optional[Dict[str, str]] = None,
    max_retries: int = 0,
    rdzv_port: int = 29500,
    mounts: Optional[List[str]] = None,
    debug: bool = False,
    tee: int = 3,
) -> specs.AppDef:
"""
Single-Program-Multiple-Data (SPMD) style application using Monarch actors.

Launches multiple copies of the same training script using Monarch's actor system
for process management instead of torch.distributed.run. Monarch actors handle
distributed environment setup (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT)
and execute the training script.

This component can be used for any distributed PyTorch training pattern including
DDP, FSDP, or RPC-based training.
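
    For example, a 2-node, 4-process-per-node job could be launched as follows
    (illustrative invocation; the CLI form mirrors ``dist.ddp``)::

        torchx run dist.monarch -j 2x4 --script train.py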

    Note: ``(cpu, gpu, memMB)`` parameters are mutually exclusive with ``h``
    (named resource); if ``h`` is specified it takes precedence when setting
    resource requirements.

    Args:
        script_args: arguments to the main module
        script: script or binary to run within the image
        m: the python module path to run
        image: image (e.g. docker)
        name: job name override in the following format:
            ``{experimentname}/{runname}`` or ``{experimentname}/`` or
            ``/{runname}`` or ``{runname}``. Uses the script or module name
            if ``{runname}`` is not specified.
        cpu: number of cpus per replica
        gpu: number of gpus per replica
        memMB: cpu memory in MB per replica
        h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
        j: ``[{min_nnodes}:]{nnodes}x{nproc_per_node}``; for gpu hosts, nproc_per_node must not exceed num gpus
        env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
        max_retries: the number of scheduler retries allowed
        rdzv_port: the port used as master_port in the distributed setup (default 29500)
        mounts: mounts to mount into the worker environment/container
            (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
            See scheduler documentation for more info.
        debug: whether to run with preset debug flags enabled
        tee: accepted for API compatibility with ``ddp()``, but not used by monarch
    """

    if (script is None) == (m is None):
        raise ValueError("exactly one of --script and -m must be specified")

    min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j)
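    # e.g. j="1x2"   -> min_nnodes=1, max_nnodes=1, nproc_per_node=2
    #      j="1:4x8" -> min_nnodes=1, max_nnodes=4, nproc_per_node=8 (elastic)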

    env = env or {}
    metadata = metadata or {}

    argname = StructuredNameArgument.parse_from(name=name, m=m, script=script)
    env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
    env["TORCHX_TRACKING_RUN_NAME"] = argname.run_name
    env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
    if debug:
        env.update(_TORCH_DEBUG_FLAGS)

    # Determine master_addr based on single vs multi-node
    if max_nnodes == 1:
        master_addr = "localhost"
    else:
        master_addr = _noquote(f"$${{{macros.rank0_env}:=localhost}}")
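        # At runtime this expands to ``${<rank0_env>:=localhost}`` inside the
        # ``bash -c`` entrypoint below, so each worker resolves the rank-0
        # host from the scheduler-provided environment variable.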

    cmd = [
        "python",
        "-m",
        "monarch.actor.torchrun",
        "--nproc_per_node",
        str(nproc_per_node),
        "--nnodes",
        str(max_nnodes),
        "--node_idx",
        f"{macros.replica_id}",
        "--master_addr",
        master_addr,
        "--master_port",
        str(rdzv_port),
    ]

    if script is not None:
        cmd += [script]
    elif m is not None:
        cmd += ["-m", m]

    cmd += script_args
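    # e.g. for j="2x4" and script="train.py", each replica runs roughly:
    #   python -m monarch.actor.torchrun --nproc_per_node 4 --nnodes 2 \
    #       --node_idx <replica_id> --master_addr <rank0 host> \
    #       --master_port 29500 train.py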

    return specs.AppDef(
        name=argname.run_name,
        roles=[
            specs.Role(
                name=get_role_name(script, m),
                image=image,
                min_replicas=min_nnodes,
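                # ``bash -c`` is used as the entrypoint so that runtime shell
                # expansions (e.g. the rank-0 master_addr above) are evaluated
                # inside the container rather than at submission time.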
entrypoint="bash",
num_replicas=int(max_nnodes),
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
args=["-c", _args_join(cmd)],
env=env,
port_map={"c10d": rdzv_port},
max_retries=max_retries,
mounts=specs.parse_mounts(mounts) if mounts else [],
)
],
metadata=metadata,
)


def get_role_name(script: Optional[str], m: Optional[str]) -> str:
    if script:
        # script name/module no extension