Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions apps/mast/qwen3_14b_mast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 4
tensor_parallel_degree: 2
data_parallel_shard_degree: 8
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
Expand All @@ -85,7 +85,7 @@ trainer:
interval: 500
async_mode: "disabled"
activation_checkpoint:
mode: selective
mode: full
selective_ac_option: op
comm:
# TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
Expand All @@ -112,7 +112,7 @@ ref_model:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
data_parallel_shard_degree: 2
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
Expand All @@ -121,6 +121,10 @@ ref_model:
enable: true
initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
initial_load_in_hf: true
comm:
# TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
# from oilfs if the trainer is not in the same region as oilfs
init_timeout_seconds: 1200

# All resource allocations
services:
Expand All @@ -131,7 +135,7 @@ services:
mesh_name: policy
hosts: 1
ref_model:
procs: 1
procs: 2
num_replicas: 2
with_gpus: true
mesh_name: ref_model
Expand Down
18 changes: 10 additions & 8 deletions apps/mast/qwen3_32b_mast.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Group Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
# >>> python -m apps.mast.main --config apps/mast/qwen3_32b_mast.yaml

# Global configuration
group_size: 8
Expand Down Expand Up @@ -71,8 +71,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 4
tensor_parallel_degree: 2
data_parallel_shard_degree: 8
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
Expand All @@ -85,7 +85,7 @@ trainer:
interval: 500
async_mode: "disabled"
activation_checkpoint:
mode: selective
mode: full
selective_ac_option: op
comm:
# TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
Expand Down Expand Up @@ -113,15 +113,18 @@ ref_model:
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
tensor_parallel_degree: 2
tensor_parallel_degree: 4
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
enable: true
initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
initial_load_in_hf: true

comm:
# TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP
# from oilfs if the trainer is not in the same region as oilfs
init_timeout_seconds: 1200
# All resource allocations
services:
policy:
Expand All @@ -131,7 +134,7 @@ services:
mesh_name: policy
hosts: 1
ref_model:
procs: 1
procs: 4
num_replicas: 2
with_gpus: true
mesh_name: ref_model
Expand All @@ -141,7 +144,6 @@ services:
num_replicas: 1
with_gpus: false
mesh_name: reward_actor

actors:
dataset:
procs: 1
Expand Down
2 changes: 2 additions & 0 deletions src/forge/actors/reference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torchtitan.config.job_config import (
Checkpoint,
Comm,
Compile,
Model,
Parallelism,
Expand All @@ -42,6 +43,7 @@ class ReferenceModel(ForgeActor):
parallelism: Parallelism = field(default_factory=Parallelism)
checkpoint: Checkpoint = field(default_factory=Checkpoint)
compile: Compile = field(default_factory=Compile)
comm: Comm = field(default_factory=Comm)
training: Training = field(
default_factory=Training
) # Needed in order to set attrs like dtype, garbage collection freq, etc.
Expand Down
4 changes: 2 additions & 2 deletions src/forge/controller/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ def create_server_handle(self) -> str:
def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None:
if not cfg:
return None
if cfg.launcher == Launcher.SLURM.value:
if cfg.launcher == Launcher.SLURM:
return Slurmlauncher()
elif cfg.launcher == Launcher.MAST.value:
elif cfg.launcher == Launcher.MAST:
if not _MAST_AVAILABLE:
raise ValueError(
"MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation"
Expand Down
4 changes: 4 additions & 0 deletions src/forge/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ class LauncherConfig:
services: dict[str, ServiceConfig] = field(default_factory=dict)
actors: dict[str, ProcessConfig] = field(default_factory=dict)

def __post_init__(self):
if isinstance(self.launcher, str):
self.launcher = Launcher(self.launcher)


@dataclass
class ProvisionerConfig:
Expand Down
Loading