Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/sft_v2/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:

processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 8
hosts: 1
procs: 8
with_gpus: true

optimizer:
Expand Down
4 changes: 2 additions & 2 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
policy_proc_config = copy(process_config)
policy_proc_config.num_procs = 1
policy_proc_config.num_hosts = None
policy_proc_config.procs = 1
policy_proc_config.hosts = None
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)
Expand Down
141 changes: 87 additions & 54 deletions src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import math
import sys
from typing import Type, TypeVar
import types
from typing import Any, Type, TypeVar

from monarch.actor import Actor, current_rank, current_size, endpoint

Expand All @@ -22,6 +23,12 @@


class ForgeActor(Actor):
procs: int = 1
hosts: int | None = None
with_gpus: bool = False
num_replicas: int = 1
_extra_config: dict[str, Any] = {}

def __init__(self, *args, **kwargs):
if not hasattr(self, "_rank"):
self._rank = current_rank().rank
Expand All @@ -48,68 +55,69 @@ def __init__(self, *args, **kwargs):
def options(
cls: Type[T],
*,
service_config: ServiceConfig | None = None,
num_replicas: int | None = None,
procs: int | None = None,
**service_kwargs,
procs: int = 1,
hosts: int | None = None,
with_gpus: bool = False,
num_replicas: int = 1,
**kwargs,
) -> Type[T]:
"""
Returns a subclass of this ForgeActor with a bound ServiceConfig.
The returned subclass can later be launched via `.as_service()`.

Usage (choose ONE of the following forms):
# Option A: construct ServiceConfig implicitly
service = await MyForgeActor.options(
num_replicas=1,
procs=2,
).as_service(...)
await service.shutdown()

# Option B: provide an explicit ServiceConfig
cfg = ServiceConfig(num_replicas=1, procs=2, ..)
service = await MyForgeActor.options(service_config=cfg).as_service(...)
await service.shutdown()

# Option C: skip options, use the default service config with num_replicas=1, procs=1
service = await MyForgeActor.as_service(...)
await service.shutdown()
Returns a dynamically created subclass of this ForgeActor with bound configuration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns a dynamically created subclass of this ForgeActor with bound configuration.
Returns a version of ForgeActor with configured resource attributes.


This method allows you to pre-configure an actor class before spawning it with
`.as_actor()` or `.as_service()`. Each call creates a separate subclass, so
multiple different configurations can coexist without interfering with each other.

---- Usage Examples ----

# Pre-configure a service with multiple replicas
service = await MyForgeActor.options(num_replicas=2, procs=2).as_service(...)
await service.shutdown()

# Default usage without calling options
service = await MyForgeActor.as_service(...)
await service.shutdown()

# Pre-configure a single actor
actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...)
await actor.shutdown()

# Default usage without calling options
actor = await MyForgeActor.as_actor(...)
await actor.shutdown()
"""

if service_config is not None:
cfg = service_config
else:
if num_replicas is None or procs is None:
raise ValueError(
"Must provide either `service_config` or (num_replicas + procs)."
)
cfg = ServiceConfig(
num_replicas=num_replicas,
procs=procs,
**service_kwargs,
)

return type(
f"{cls.__name__}Service",
(cls,),
{"_service_config": cfg},
)
attrs = {
"procs": procs,
"hosts": hosts,
"with_gpus": with_gpus,
"num_replicas": num_replicas,
"_extra_config": kwargs,
}

return type(cls.__name__, (cls,), attrs)

@classmethod
async def as_service(cls: Type[T], **actor_kwargs) -> "ServiceInterface":
"""
Convenience method to spawn this actor as a Service using default configuration.
If `.options()` was called, it will use the bound ServiceConfig;
otherwise defaults to 1 replica, 1 proc.
Spawns this actor as a Service using the configuration stored in `.options()`,
or defaults if `.options()` was not called.

The configuration values stored in the subclass returned by `.options()` (like
`procs` and `num_replicas`) are used to construct a ServiceConfig instance.
If no configuration was stored, defaults to a single replica with one process.
"""
# Lazy import to avoid top-level dependency issues
from forge.controller.service import Service, ServiceInterface

# Use _service_config if already set by options(), else default
cfg = getattr(cls, "_service_config", None)
if cfg is None:
cfg = ServiceConfig(num_replicas=1, procs=1)
# dynamically create a configured subclass for consistency
cls = type(f"{cls.__name__}Service", (cls,), {"_service_config": cfg})
cfg_kwargs = {
"procs": cls.procs,
"hosts": cls.hosts,
"with_gpus": cls.with_gpus,
"num_replicas": cls.num_replicas,
**cls._extra_config, # all extra fields
}
cfg = ServiceConfig(**cfg_kwargs)

logger.info("Spawning Service Actor for %s", cls.__name__)
service = Service(cfg, cls, actor_kwargs)
Expand Down Expand Up @@ -154,7 +162,7 @@ async def set_env(self, addr: str, port: str):
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
async def launch(cls, **kwargs) -> "ForgeActor":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add *args here? This solves the *args related TODO that's listed here!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in launch and as_actor. Also tested in test_as_actor_with_kwargs_config

"""Provisions and deploys a new actor.

This method is used by `Service` to provision a new replica.
Expand All @@ -167,7 +175,14 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor
a homogeneous set of actors on a single proc mesh.

"""
proc_mesh = await get_proc_mesh(process_config=process_config)
# Build process config
cfg = ProcessConfig(
procs=cls.procs,
hosts=cls.hosts,
with_gpus=cls.with_gpus,
)

proc_mesh = await get_proc_mesh(process_config=cfg)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
Expand All @@ -180,10 +195,28 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor
await actor.setup.call()
return actor

@classmethod
async def as_actor(cls: Type[T], **actor_kwargs) -> T:
"""
Spawns a single actor using the configuration stored in `.options()`, or defaults.

The configuration values stored in the subclass returned by `.options()` (like
`procs`) are used to construct a ProcessConfig instance.
If no configuration was stored, defaults to a single process with no GPU.
"""
logger.info("Spawning single actor %s", cls.__name__)
actor = await cls.launch(**actor_kwargs)

# Patch shutdown to bypass endpoint system
actor.shutdown = types.MethodType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm this is a hack, we shouldn't be doing this. I'm guessing it's because we want to preserve the ability to

svc = MyActor.as_service()

await svc.shutdown()

?

Copy link
Member Author

@DNXie DNXie Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, as_service returns a ServiceInterface. So when we call service.shutdown(), we are actually calling ServiceInterface.shutdown

The reason I have to do this hacky thing is:
Without it, actor.shutdown() gives me this error:

RuntimeError: Actor <class 'tests.unit_tests.test_service.Counter'>.shutdown is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)

If I simply decorate shutdown with @endpoint, we'd have to call it like

await actor.shutdown.call()

But it would still give error:

AssertionError("Called shutdown on a replica with no proc_mesh.")

Any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see. Ok in that case, I think what we should do is not do actor.shutdown() for now, and just rely on eg

await RLTrainer.stop(trainer)

for now. Maybe what we can do next is have the provisioner keep track of all of the proc meshes, and do a global shutdown()? Including all the services etc. we can discuss more, just want to unblock this PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Done!

lambda self: self._class.shutdown(self), actor
)

return actor

@classmethod
async def shutdown(cls, actor: "ForgeActor"):
"""Shuts down an actor.

This method is used by `Service` to teardown a replica.
"""
if actor._proc_mesh is None:
Expand Down
4 changes: 2 additions & 2 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def _get_provisioner():

async def get_proc_mesh(config: ProcessConfig) -> ProcMesh:
return await _get_provisioner().get_proc_mesh(
num_procs=config.num_procs,
num_procs=config.procs,
with_gpus=config.with_gpus,
num_hosts=config.num_hosts,
num_hosts=config.hosts,
)


Expand Down
9 changes: 4 additions & 5 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Optional

Expand Down Expand Up @@ -157,11 +157,10 @@ async def initialize(self):
try:
# Deploy the actor and its underlying resources
logger.debug(f"Launching actor for replica {self.idx}")
self.actor = await self.actor_def.launch(
process_config=self.proc_config,
**self.actor_kwargs,
)

self.actor = await self.actor_def.options(
**asdict(self.proc_config)
).as_actor(**self.actor_kwargs)
# Transition to healthy state and start processing
self.state = ReplicaState.HEALTHY
self.start_processing()
Expand Down
1 change: 1 addition & 0 deletions src/forge/data_models/episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Sequence

import torch

from forge.data_models.scored_completion import ScoredCompletion


Expand Down
1 change: 0 additions & 1 deletion src/forge/data_models/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Role(Enum):
Expand Down
10 changes: 5 additions & 5 deletions src/forge/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class State:
class ProcessConfig:
"""A proc_mesh config for the torchx scheduler."""

num_procs: int = 1
procs: int = 1
with_gpus: bool = False
num_hosts: int | None = None
hosts: int | None = None


@dataclass
Expand Down Expand Up @@ -121,12 +121,12 @@ class ServiceConfig:

def to_process_config(self) -> ProcessConfig:
"""Extract ProcessConfig from this ServiceConfig.
Maps procs to num_procs for ProcessConfig.
Maps procs to procs for ProcessConfig.
"""
return ProcessConfig(
num_procs=self.procs,
procs=self.procs,
with_gpus=self.with_gpus,
num_hosts=self.hosts,
hosts=self.hosts,
)


Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/test_policy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import asdict
from typing import Callable

import pytest
Expand All @@ -17,7 +18,6 @@
from forge.actors.trainer import RLTrainer
from forge.controller.service import ServiceConfig
from forge.data.sharding import VLLMSharding

from transformers import AutoModelForCausalLM

requires_cuda = pytest.mark.skipif(
Expand Down Expand Up @@ -262,7 +262,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
policy_config, service_config = get_configs(
worker_size=worker_size, tp_size=worker_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
await policy.update_weights.call()
Expand Down Expand Up @@ -302,7 +302,7 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
policy_config, service_config = get_configs(
worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
await policy.update_weights.call()
Expand Down
8 changes: 3 additions & 5 deletions tests/unit_tests/test_provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import pytest
from forge.controller.provisioner import GpuManager, Provisioner
from forge.types import ProcessConfig


class TestGpuManagerCudaVisibleDevices:
Expand Down Expand Up @@ -158,11 +157,10 @@ async def test_get_proc_mesh_respects_cuda_visible_devices(self):

# Note - this can run even on CPU because with_gpus just sets environment
# variables.
config = ProcessConfig(num_procs=2, with_gpus=True, num_hosts=None)
_ = await provisioner.get_proc_mesh(
num_procs=config.num_procs,
with_gpus=config.with_gpus,
num_hosts=config.num_hosts,
num_procs=2,
with_gpus=True,
num_hosts=None,
)
# Verify GPUs were allocated from available set
remaining_available = local_gpu_manager.get_available_gpus()
Expand Down
Loading
Loading