28 changes: 8 additions & 20 deletions apps/grpo/main.py
@@ -26,20 +26,13 @@
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.types import (
Launcher,
LauncherConfig,
ProcessConfig,
ProvisionerConfig,
ServiceConfig,
)
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
@@ -320,25 +313,20 @@ async def main(cfg: DictConfig):
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens

# init provisioner
await init_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(
launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.SLURM.value)),
job_name=cfg.get(JOB_NAME_KEY, None),
services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
# ---- Global setups ---- #
if cfg.get("provisioner", None) is not None:
await init_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(**cfg.provisioner.launcher)
)
)
)

# initialize before spawning services
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

# ---- Setup services ---- #
await ts.initialize(strategy=ts.ControllerStorageVolumes())

(
dataloader,
policy,
3 changes: 3 additions & 0 deletions apps/grpo/qwen3_32b.yaml
@@ -10,6 +10,9 @@ max_res_tokens: 512
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm

# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas

3 changes: 0 additions & 3 deletions src/forge/actors/policy.py
@@ -186,15 +186,12 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
policy_proc_config.procs = 1
policy_proc_config.hosts = None
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)

if isinstance(engine_config, Mapping):
engine_config = EngineConfig.from_dict(engine_config)

vllm_config = engine_config.create_vllm_config()
# TODO (felipemello): LocalFetcherActor doesnt spawn with this, so cannot
# do logging within PolicyWorker
workers = worker_procs.spawn(
"vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
)
17 changes: 15 additions & 2 deletions src/forge/controller/__init__.py
@@ -4,6 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .proc_mesh import get_proc_mesh, stop_proc_mesh
from .provisioner import (
get_proc_mesh,
host_mesh_from_proc,
init_provisioner,
shutdown,
stop_proc_mesh,
)

__all__ = ["stop_proc_mesh", "get_proc_mesh", "ForgeActor"]
__all__ = [
"ForgeActor",
"get_proc_mesh",
"stop_proc_mesh",
"init_provisioner",
"shutdown",
"host_mesh_from_proc",
]
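
The practical effect of this export change for downstream code: the proc-mesh helpers now live in (and are re-exported from) the provisioner module, so imports go through the package root rather than the deleted `forge.controller.proc_mesh` module. A minimal sketch, assuming the re-exported functions keep their previous signatures:

```python
# Before this PR (module deleted below):
# from forge.controller.proc_mesh import get_proc_mesh, stop_proc_mesh

# After this PR, everything is re-exported from the package root,
# backed by forge/controller/provisioner.py:
from forge.controller import (
    ForgeActor,
    get_proc_mesh,
    host_mesh_from_proc,
    init_provisioner,
    shutdown,
    stop_proc_mesh,
)
```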
28 changes: 1 addition & 27 deletions src/forge/controller/actor.py
@@ -12,7 +12,7 @@

from monarch.actor import Actor, current_rank, current_size, endpoint

from forge.controller.proc_mesh import get_proc_mesh, stop_proc_mesh
from forge.controller.provisioner import get_proc_mesh, stop_proc_mesh

from forge.types import ProcessConfig, ServiceConfig

@@ -144,28 +144,6 @@ async def setup(self):
"""
pass

@endpoint
async def set_env(self, addr: str, port: str):
"""A temporary workaround to set master addr/port.

TODO - issues/144. This should be done in proc_mesh creation.
The ideal path:
- Create a host mesh
- Grab a host from host mesh, from proc 0 spawn an actor that
gets addr/port
- Spawn procs on the HostMesh with addr/port, setting the
addr/port in bootstrap.

We can't currently do this because HostMesh only supports single
proc_mesh creation at the moment. This will be possible once
we have "proper HostMesh support".

"""
import os

os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, *args, **kwargs) -> "ForgeActor":
"""Provisions and deploys a new actor.
@@ -193,10 +171,6 @@ async def launch(cls, *args, **kwargs) -> "ForgeActor":
actor_name = kwargs.pop("name", cls.__name__)
actor = proc_mesh.spawn(actor_name, cls, *args, **kwargs)
actor._proc_mesh = proc_mesh

if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"):
host, port = proc_mesh._hostname, proc_mesh._port
await actor.set_env.call(addr=host, port=port)
await actor.setup.call()
return actor

43 changes: 20 additions & 23 deletions src/forge/controller/launcher.py
@@ -6,8 +6,9 @@

import getpass
import os
import socket
import subprocess

import tempfile
import uuid
from typing import Any

@@ -25,35 +26,25 @@

from forge.types import Launcher, LauncherConfig

_MAST_AVAILABLE = False

try:
from monarch._src.actor.actor_mesh import current_rank
from monarch._src.actor.meta.allocator import MastAllocator, MastAllocatorConfig
from monarch.tools.components.meta import hyperactor as meta_hyperactor
from torchx.specs import AppState
from torchx.specs.fb.component_helpers import Packages

_MAST_AVAILABLE = True
except ImportError as e:
print(f"Warning: Monarch meta/fb inetrnal imports failed: {e}")
print("Monarch functionality will be limited")
# This means there is an error with MAST
Contributor:
should it silently fail?

Contributor @Ritesh1905 (Oct 6, 2025):
For now, yes (though some sort of logging would be ideal), until we figure out the right way to segregate Meta-only code blocks. These dependencies are Meta-internal and are installed via an internal fbpkg monarch build. The env setup for MAST requires a separate installation script.

Contributor (Author):
I've just added a block that says: if the imports failed and you're trying to use MAST, print an error message saying the imports failed and that you should check your build was correct.

pass

JOB_NAME_KEY = "job_name"
LAUNCHER_KEY = "launcher"


def _get_port() -> str:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", 0))
addr = s.getsockname()
port = addr[1]
return str(port)


class SetupActor(Actor):
@endpoint
def get_info(self) -> [str, str]:
return socket.gethostname(), _get_port()


class MastSetupActor(SetupActor):
class MastSetupActor(Actor):
@endpoint
def mount(self, mount_dst: str):
point = current_rank()
@@ -138,11 +129,12 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
role.resource.cpu = 128
role.resource.gpu = 8

# TODO - multi scheduler support
# Note - we cannot add in an empty workspace, so we create a fake temporary one
temp_workspace = tempfile.mkdtemp(prefix="forge_workspace_")
server_config = Config(
scheduler="slurm",
appdef=appdef,
workspace=monarch.tools.config.workspace.Workspace(dirs=[""]),
workspace=monarch.tools.config.workspace.Workspace(dirs=[temp_workspace]),
)
server_info = await commands.get_or_create(
"forge_job",
@@ -157,8 +149,7 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
return alloc, None, server_name # (Allocator, AllocConstraints, SeverName)

async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]:
setup = procs.spawn(f"setup-{uuid.uuid1()}", SetupActor)
return await setup.get_info.choose()
return


class Mastlauncher(BaseLauncher):
@@ -306,9 +297,15 @@ def create_server_handle(self) -> str:


def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None:
if not cfg or cfg.launcher == Launcher.SLURM:
if not cfg:
return None
if cfg.launcher == Launcher.SLURM:
return Slurmlauncher()
elif cfg.launcher == Launcher.MAST:
if not _MAST_AVAILABLE:
raise ValueError(
"MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation"
)
return Mastlauncher(cfg)
else:
raise ValueError(f"Unsupported config provided, got {cfg}")
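
A minimal usage sketch of the new `get_launcher` dispatch; the bare `LauncherConfig(launcher=...)` constructions below are assumptions for illustration (in practice the config is built from the `provisioner.launcher` block of the YAML, as in apps/grpo/main.py):

```python
from forge.controller.launcher import get_launcher
from forge.types import Launcher, LauncherConfig

# No launcher config at all -> no launcher is created.
assert get_launcher(None) is None

# SLURM uses the open-source Slurmlauncher.
slurm = get_launcher(LauncherConfig(launcher=Launcher.SLURM))  # hypothetical minimal config

# MAST requires the Meta-internal monarch packages; on builds where those
# imports failed (_MAST_AVAILABLE is False), this now raises a clear ValueError
# instead of failing later with an opaque error.
try:
    mast = get_launcher(LauncherConfig(launcher=Launcher.MAST))  # hypothetical minimal config
except ValueError as err:
    print(err)
```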
30 changes: 0 additions & 30 deletions src/forge/controller/proc_mesh.py

This file was deleted.
