apps/grpo/main.py (1 addition, 3 deletions)

@@ -316,9 +316,7 @@ async def main(cfg: DictConfig):
     # ---- Global setups ---- #
     if cfg.get("provisioner", None) is not None:
         await init_provisioner(
-            ProvisionerConfig(
-                launcher_config=LauncherConfig(**cfg.provisioner.launcher)
-            )
+            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
         )
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
     mlogger = await get_or_create_metric_logger()

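For context on this call-site change: a minimal sketch of the config shape each form expects, assuming the keys under provisioner mirror the LauncherConfig fields. The dicts below are hypothetical stand-ins for the parsed DictConfig, not code from the PR.

    # Old call: LauncherConfig(**cfg.provisioner.launcher) -> fields nested one level deeper
    old_cfg = {"provisioner": {"launcher": {"launcher": "slurm"}}}

    # New call: LauncherConfig(**cfg.provisioner) -> fields sit directly under "provisioner",
    # matching the flat block added to deepseek_r1.yaml further down.
    new_cfg = {"provisioner": {"launcher": "slurm"}}
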
src/forge/controller/launcher.py (2 additions, 2 deletions)

@@ -299,9 +299,9 @@ def create_server_handle(self) -> str:
 def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None:
     if not cfg:
         return None
-    if cfg.launcher == Launcher.SLURM:
+    if cfg.launcher == Launcher.SLURM.value:
         return Slurmlauncher()
-    elif cfg.launcher == Launcher.MAST:
+    elif cfg.launcher == Launcher.MAST.value:
         if not _MAST_AVAILABLE:
             raise ValueError(
                 "MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation"

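The switch to comparing against Launcher.SLURM.value suggests cfg.launcher now arrives as the raw string parsed from YAML rather than a Launcher member. A minimal sketch of the distinction, assuming Launcher is a string-valued Enum whose values match the slurm/mast strings used in the configs:

    from enum import Enum

    class Launcher(Enum):  # assumed definition, matching the string values used in the configs
        SLURM = "slurm"
        MAST = "mast"

    cfg_launcher = "slurm"                        # as parsed from YAML
    print(cfg_launcher == Launcher.SLURM)         # False: a plain str never equals an Enum member
    print(cfg_launcher == Launcher.SLURM.value)   # True: str-to-str comparison
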
src/forge/controller/provisioner.py (3 additions)

@@ -232,6 +232,9 @@ def bootstrap(env: dict[str, str]):
         env_vars["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
         env_vars["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
 
+        # Shows detailed logs for Monarch rust failures
+        env_vars["RUST_BACKTRACE"] = "1"
+
         procs = host_mesh.spawn_procs(
             per_host={"gpus": num_procs},
             bootstrap=functools.partial(bootstrap, env=env_vars),

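For orientation, a rough sketch of the bootstrap pattern this hunk feeds into, assuming the hook simply exports the collected variables inside each spawned process. The function body below is illustrative only, not the provisioner's actual implementation.

    import functools
    import os

    def bootstrap(env: dict[str, str]) -> None:
        # Runs in each spawned process before user code; exports the collected env vars.
        for k, v in env.items():
            os.environ[k] = v

    env_vars = {"RUST_BACKTRACE": "1"}  # full Rust backtraces when a Monarch actor fails
    hook = functools.partial(bootstrap, env=env_vars)
    hook()  # the process mesh would invoke something like this during startup
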
src/forge/types.py (3 additions, 3 deletions)

@@ -159,9 +159,9 @@ class LauncherConfig:
     """A launcher config for the scheduler."""
 
     launcher: Launcher
-    job_name: str
-    services: dict[str, ServiceConfig]
-    actors: dict[str, ProcessConfig]
+    job_name: str = ""
+    services: dict[str, ServiceConfig] = field(default_factory=dict)
+    actors: dict[str, ProcessConfig] = field(default_factory=dict)
 
 
 @dataclass

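With these defaults, a LauncherConfig can be constructed from a mapping that only names the launcher, which is exactly what the flattened provisioner block supplies. A minimal sketch, with the field types simplified:

    from dataclasses import dataclass, field

    @dataclass
    class LauncherConfig:
        launcher: str                           # simplified; the real field uses the Launcher type
        job_name: str = ""
        services: dict = field(default_factory=dict)
        actors: dict = field(default_factory=dict)

    print(LauncherConfig(**{"launcher": "slurm"}))  # valid now that the other fields have defaults
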
tests/sandbox/vllm/deepseek_r1.yaml (3 additions)

@@ -13,6 +13,9 @@ policy:
     guided_decoding: false
   max_tokens: 512
 
+provisioner:
+  launcher: slurm
+
 services:
   policy:
     procs: 8

tests/sandbox/vllm/main.py (7 additions, 1 deletion)

@@ -15,17 +15,23 @@
 
 from forge.actors.policy import Policy
 from forge.cli.config import parse
-from forge.controller.provisioner import shutdown
+
+from forge.controller.provisioner import init_provisioner, shutdown
 
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.types import LauncherConfig, ProvisionerConfig
 from omegaconf import DictConfig
 
 os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
 os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
 
 
 async def run(cfg: DictConfig):
+    if cfg.get("provisioner", None) is not None:
+        await init_provisioner(
+            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+        )
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)