diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f4c7988bb..c64f00bc2 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -316,9 +316,7 @@ async def main(cfg: DictConfig): # ---- Global setups ---- # if cfg.get("provisioner", None) is not None: await init_provisioner( - ProvisionerConfig( - launcher_config=LauncherConfig(**cfg.provisioner.launcher) - ) + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await get_or_create_metric_logger() diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py index f2fe5f0f2..f75114114 100644 --- a/src/forge/controller/launcher.py +++ b/src/forge/controller/launcher.py @@ -299,9 +299,9 @@ def create_server_handle(self) -> str: def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None: if not cfg: return None - if cfg.launcher == Launcher.SLURM: + if cfg.launcher == Launcher.SLURM.value: return Slurmlauncher() - elif cfg.launcher == Launcher.MAST: + elif cfg.launcher == Launcher.MAST.value: if not _MAST_AVAILABLE: raise ValueError( "MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation" diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 35ab7e525..13c19ac50 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -232,6 +232,9 @@ def bootstrap(env: dict[str, str]): env_vars["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" env_vars["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824" + # Shows detailed logs for Monarch rust failures + env_vars["RUST_BACKTRACE"] = "1" + procs = host_mesh.spawn_procs( per_host={"gpus": num_procs}, bootstrap=functools.partial(bootstrap, env=env_vars), diff --git a/src/forge/types.py b/src/forge/types.py index 4eeac55bd..45312db79 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -159,9 +159,9 @@ class LauncherConfig: """A launcher config for the scheduler.""" launcher: Launcher - job_name: str - services: dict[str, ServiceConfig] - actors: dict[str, ProcessConfig] + job_name: str = "" + services: dict[str, ServiceConfig] = field(default_factory=dict) + actors: dict[str, ProcessConfig] = field(default_factory=dict) @dataclass diff --git a/tests/sandbox/vllm/deepseek_r1.yaml b/tests/sandbox/vllm/deepseek_r1.yaml index 252b20a3f..2255a5c03 100644 --- a/tests/sandbox/vllm/deepseek_r1.yaml +++ b/tests/sandbox/vllm/deepseek_r1.yaml @@ -13,6 +13,9 @@ policy: guided_decoding: false max_tokens: 512 +provisioner: + launcher: slurm + services: policy: procs: 8 diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index b425af324..0f3ce662c 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -15,10 +15,12 @@ from forge.actors.policy import Policy from forge.cli.config import parse -from forge.controller.provisioner import shutdown + +from forge.controller.provisioner import init_provisioner, shutdown from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger +from forge.types import LauncherConfig, ProvisionerConfig from omegaconf import DictConfig os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" @@ -26,6 +28,10 @@ async def run(cfg: DictConfig): + if cfg.get("provisioner", None) is not None: + await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg)