diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f4c7988bb..73bf5b766 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -32,7 +32,7 @@ from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer -from forge.types import LauncherConfig, ProvisionerConfig +from forge.types import LauncherConfig, ProvisionerConfig, Launcher from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig @@ -314,12 +314,12 @@ async def main(cfg: DictConfig): max_res_tokens = cfg.max_res_tokens # ---- Global setups ---- # - if cfg.get("provisioner", None) is not None: - await init_provisioner( - ProvisionerConfig( - launcher_config=LauncherConfig(**cfg.provisioner.launcher) - ) + # if cfg.get("provisioner", None) is not None: + await init_provisioner( + ProvisionerConfig( + launcher_config=LauncherConfig(launcher=Launcher("slurm")), ) + ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index ca88b349a..c22995292 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -3,7 +3,7 @@ # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability # Global configuration -group_size: 2 +group_size: 8 batch_size: 8 max_req_tokens: 512 max_res_tokens: 512 @@ -11,7 +11,8 @@ model: "Qwen/Qwen3-32B" off_by_n: 1 # Off by one by default provisioner: - launcher: slurm + launcher_config: + launcher: slurm # Main loop configuration rollout_threads: 1 # Recommended to set equal to policy.num_replicas @@ -37,7 +38,7 @@ dataset: policy: engine_config: model: ${model} - tensor_parallel_size: 4 + tensor_parallel_size: 8 pipeline_parallel_size: 1 enforce_eager: false sampling_config: @@ -69,8 +70,8 @@ trainer: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: -1 - tensor_parallel_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 4 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -136,8 +137,8 @@ actors: procs: 1 with_gpus: false trainer: - procs: 8 - hosts: 1 + procs: 4 + # hosts: 1 with_gpus: true replay_buffer: procs: 1 diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index a899da6f0..624c250e6 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -166,6 +166,8 @@ async def launch(cls, *args, **kwargs) -> "ForgeActor": mesh_name=cls.mesh_name, ) + print(f"Spawning proc mesh {cls.mesh_name} with gpus {cls.with_gpus}") + proc_mesh = await get_proc_mesh(process_config=cfg) actor_name = kwargs.pop("name", cls.__name__) diff --git a/src/forge/types.py b/src/forge/types.py index 4eeac55bd..c9b1d6d20 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -159,9 +159,9 @@ class LauncherConfig: """A launcher config for the scheduler.""" launcher: Launcher - job_name: str - services: dict[str, ServiceConfig] - actors: dict[str, ProcessConfig] + job_name: str | None = None + services: dict[str, ServiceConfig] | None = None + actors: dict[str, ProcessConfig] | None = None @dataclass