diff --git a/apps/mast/qwen3_14b_mast.yaml b/apps/mast/qwen3_14b_mast.yaml index 83d5b8103..484a71538 100644 --- a/apps/mast/qwen3_14b_mast.yaml +++ b/apps/mast/qwen3_14b_mast.yaml @@ -71,8 +71,8 @@ trainer: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 4 - tensor_parallel_degree: 2 + data_parallel_shard_degree: 8 + tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -85,7 +85,7 @@ trainer: interval: 500 async_mode: "disabled" activation_checkpoint: - mode: selective + mode: full selective_ac_option: op comm: # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP @@ -112,7 +112,7 @@ ref_model: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 1 + data_parallel_shard_degree: 2 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 @@ -121,6 +121,10 @@ ref_model: enable: true initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 initial_load_in_hf: true + comm: + # TODO: needs to be revisited. 
causing NCCL timeouts on inits when loading CP + # from oilfs if the trainer is not in the same region as oilfs + init_timeout_seconds: 1200 # All resource allocations services: @@ -131,7 +135,7 @@ services: mesh_name: policy hosts: 1 ref_model: - procs: 1 + procs: 2 num_replicas: 2 with_gpus: true mesh_name: ref_model diff --git a/apps/mast/qwen3_32b_mast.yaml b/apps/mast/qwen3_32b_mast.yaml index 0db8f4af3..47368becd 100644 --- a/apps/mast/qwen3_32b_mast.yaml +++ b/apps/mast/qwen3_32b_mast.yaml @@ -1,5 +1,5 @@ # Grouped Relative Policy Optimization (GRPO) -# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml +# >>> python -m apps.mast.main --config apps/mast/qwen3_32b_mast.yaml # Global configuration group_size: 8 @@ -71,8 +71,8 @@ trainer: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: 4 - tensor_parallel_degree: 2 + data_parallel_shard_degree: 8 + tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -85,7 +85,7 @@ trainer: interval: 500 async_mode: "disabled" activation_checkpoint: - mode: selective + mode: full selective_ac_option: op comm: # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP @@ -113,7 +113,7 @@ ref_model: parallelism: data_parallel_replicate_degree: 1 data_parallel_shard_degree: 1 - tensor_parallel_degree: 2 + tensor_parallel_degree: 4 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -121,7 +121,10 @@ ref_model: enable: true initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 initial_load_in_hf: true - + comm: + # TODO: needs to be revisited. 
causing NCCL timeouts on inits when loading CP + # from oilfs if the trainer is not in the same region as oilfs + init_timeout_seconds: 1200 # All resource allocations services: policy: @@ -131,7 +134,7 @@ services: mesh_name: policy hosts: 1 ref_model: - procs: 1 + procs: 4 num_replicas: 2 with_gpus: true mesh_name: ref_model @@ -141,7 +144,6 @@ services: num_replicas: 1 with_gpus: false mesh_name: reward_actor - actors: dataset: procs: 1 diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index cc57e5246..7ab9f9527 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -18,6 +18,7 @@ from torchtitan.config.job_config import ( Checkpoint, + Comm, Compile, Model, Parallelism, @@ -42,6 +43,7 @@ class ReferenceModel(ForgeActor): parallelism: Parallelism = field(default_factory=Parallelism) checkpoint: Checkpoint = field(default_factory=Checkpoint) compile: Compile = field(default_factory=Compile) + comm: Comm = field(default_factory=Comm) training: Training = field( default_factory=Training ) # Needed in order to set attrs like dtype, garbage collection freq, etc. diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py index f75114114..f2fe5f0f2 100644 --- a/src/forge/controller/launcher.py +++ b/src/forge/controller/launcher.py @@ -299,9 +299,9 @@ def create_server_handle(self) -> str: def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None: if not cfg: return None - if cfg.launcher == Launcher.SLURM.value: + if cfg.launcher == Launcher.SLURM: return Slurmlauncher() - elif cfg.launcher == Launcher.MAST.value: + elif cfg.launcher == Launcher.MAST: if not _MAST_AVAILABLE: raise ValueError( "MAST imports did not succeed, cannot launch MAST jobs. 
Please verify your installation" diff --git a/src/forge/types.py b/src/forge/types.py index 45312db79..6a9dcc122 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -163,6 +163,10 @@ class LauncherConfig: services: dict[str, ServiceConfig] = field(default_factory=dict) actors: dict[str, ProcessConfig] = field(default_factory=dict) + def __post_init__(self): + if isinstance(self.launcher, str): + self.launcher = Launcher(self.launcher) + @dataclass class ProvisionerConfig: