Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/sft_v2/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:

processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 8
hosts: 1
procs: 8
with_gpus: true

optimizer:
Expand Down
4 changes: 2 additions & 2 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
policy_proc_config = copy(process_config)
policy_proc_config.num_procs = 1
policy_proc_config.num_hosts = None
policy_proc_config.procs = 1
policy_proc_config.hosts = None
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)
Expand Down
135 changes: 78 additions & 57 deletions src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import math
import sys
import types
from typing import Type, TypeVar

from monarch.actor import Actor, current_rank, current_size, endpoint
Expand All @@ -21,6 +22,16 @@
T = TypeVar("T", bound="ForgeActor")


def filter_config_params(cls, kwargs: dict) -> dict:
from inspect import signature

"""
Filters kwargs to only include parameters that are valid for the given config class.
"""
sig = signature(cls)
return {k: v for k, v in kwargs.items() if k in sig.parameters}


class ForgeActor(Actor):
def __init__(self, *args, **kwargs):
if not hasattr(self, "_rank"):
Expand Down Expand Up @@ -48,68 +59,57 @@ def __init__(self, *args, **kwargs):
def options(
cls: Type[T],
*,
service_config: ServiceConfig | None = None,
num_replicas: int | None = None,
procs: int | None = None,
**service_kwargs,
procs: int,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also hosts: int, with_gpu: bool and num_replicas: int | None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of put them all in **kwargs since only procs is required for both service and actor. Do you think it is better to explicitly list them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, please explicitly list them

**kwargs,
) -> Type[T]:
"""
Returns a subclass of this ForgeActor with a bound ServiceConfig.
The returned subclass can later be launched via `.as_service()`.

Usage (choose ONE of the following forms):
# Option A: construct ServiceConfig implicitly
service = await MyForgeActor.options(
num_replicas=1,
procs=2,
).as_service(...)
await service.shutdown()

# Option B: provide an explicit ServiceConfig
cfg = ServiceConfig(num_replicas=1, procs=2, ..)
service = await MyForgeActor.options(service_config=cfg).as_service(...)
await service.shutdown()

# Option C: skip options, use the default service config with num_replicas=1, procs=1
service = await MyForgeActor.as_service(...)
await service.shutdown()
Returns a dynamically created subclass of this ForgeActor with bound configuration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns a dynamically created subclass of this ForgeActor with bound configuration.
Returns a version of ForgeActor with configured resource attributes.


This method allows you to pre-configure an actor class before spawning it with
`.as_actor()` or `.as_service()`. Each call creates a separate subclass, so
multiple different configurations can coexist without interfering with each other.

---- Usage Examples ----

# Pre-configure a service with multiple replicas
service = await MyForgeActor.options(num_replicas=2, procs=2).as_service(...)
await service.shutdown()

# Default usage without calling options
service = await MyForgeActor.as_service(...)
await service.shutdown()

# Pre-configure a single actor
actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...)
await actor.shutdown()

# Default usage without calling options
actor = await MyForgeActor.as_actor(...)
await actor.shutdown()
"""

if service_config is not None:
cfg = service_config
else:
if num_replicas is None or procs is None:
raise ValueError(
"Must provide either `service_config` or (num_replicas + procs)."
)
cfg = ServiceConfig(
num_replicas=num_replicas,
procs=procs,
**service_kwargs,
)

return type(
f"{cls.__name__}Service",
(cls,),
{"_service_config": cfg},
)
cfg_dict = {"procs": procs, **kwargs}
return type(cls.__name__, (cls,), cfg_dict)

@classmethod
async def as_service(cls: Type[T], **actor_kwargs) -> "ServiceInterface":
"""
Convenience method to spawn this actor as a Service using default configuration.
If `.options()` was called, it will use the bound ServiceConfig;
otherwise defaults to 1 replica, 1 proc.
Spawns this actor as a Service using the configuration stored in `.options()`,
or defaults if `.options()` was not called.

The configuration values stored in the subclass returned by `.options()` (like
`procs` and `num_replicas`) are used to construct a ServiceConfig instance.
If no configuration was stored, defaults to a single replica with one process.
"""
# Lazy import to avoid top-level dependency issues
from forge.controller.service import Service, ServiceInterface

# Use _service_config if already set by options(), else default
cfg = getattr(cls, "_service_config", None)
if cfg is None:
cfg = ServiceConfig(num_replicas=1, procs=1)
# dynamically create a configured subclass for consistency
cls = type(f"{cls.__name__}Service", (cls,), {"_service_config": cfg})
class_attrs = {k: v for k, v in cls.__dict__.items() if not k.startswith("__")}
if "procs" not in class_attrs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow up comment on explicit attributes, this for e.g. is unclear and can be pretty brittle

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in the latest version

class_attrs["procs"] = 1
if "num_replicas" not in class_attrs:
class_attrs["num_replicas"] = 1
cfg = ServiceConfig(**filter_config_params(ServiceConfig, class_attrs))

logger.info("Spawning Service Actor for %s", cls.__name__)
service = Service(cfg, cls, actor_kwargs)
Expand Down Expand Up @@ -154,7 +154,7 @@ async def set_env(self, addr: str, port: str):
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
async def launch(cls, **kwargs) -> "ForgeActor":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add *args here? This solves the *args related TODO that's listed here!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in launch and as_actor. Also tested in test_as_actor_with_kwargs_config

"""Provisions and deploys a new actor.

This method is used by `Service` to provision a new replica.
Expand All @@ -167,7 +167,13 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor
a homogeneous set of actors on a single proc mesh.

"""
proc_mesh = await get_proc_mesh(process_config=process_config)
# Build process config from class attributes with defaults
cfg = ProcessConfig(
procs=getattr(cls, "procs", 1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally try and use getattr as little as possible. If it's used too much it can mask real errors that can be really hard to debug later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a fallback when the user doesn’t specify configs via .options(). In this case, the original ForgeActor class doesn’t have attributes like procs. If we are getting rid of getattr, one way I can think of is to add these attributes to ForgeActor class like

class ForgeActor(Actor):
    procs: int = 1
    hosts: int | None = None
    with_gpus: bool = False
    num_replicas: int = 1

    def __init__(self, *args, **kwargs):

But either way, it means the default values are specified in three places:

  1. In types.py
  2. As default values in .options()
  3. As attributes on the ForgeActor class OR here in launch.

I’m not sure if there’s a cleaner way to handle this. I’ve updated the code accordingly (get rid of getattr), please take a look and let me know if you have any suggestions or improvements.

hosts=getattr(cls, "hosts", None),
with_gpus=getattr(cls, "with_gpus", False),
)
proc_mesh = await get_proc_mesh(process_config=cfg)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
Expand All @@ -181,11 +187,26 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor
return actor

@classmethod
async def shutdown(cls, actor: "ForgeActor"):
"""Shuts down an actor.
async def as_actor(cls: Type[T], **actor_kwargs) -> T:
"""
Spawns a single actor using the configuration stored in `.options()`, or defaults.

This method is used by `Service` to teardown a replica.
The configuration values stored in the subclass returned by `.options()` (like
`procs`) are used to construct a ProcessConfig instance.
If no configuration was stored, defaults to a single process with no GPU.
"""
if actor._proc_mesh is None:
logger.info("Spawning single actor %s", cls.__name__)
actor = await cls.launch(**actor_kwargs)

# Patch shutdown to bypass endpoint system
actor.shutdown = types.MethodType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm this is a hack, we shouldn't be doing this. I'm guessing it's because we want to preserve the ability to

svc = MyActor.as_service()

await svc.shutdown()

?

Copy link
Member Author

@DNXie DNXie Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, as_service returns a ServiceInterface. So when we call service.shutdown(), we are actually calling ServiceInterface.shutdown

The reason I have to do this hacky thing is:
Without it, actor.shutdown() gives me this error:

RuntimeError: Actor <class 'tests.unit_tests.test_service.Counter'>.shutdown is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)

If I simply decorate shutdown with @endpoint, we'd have to call it like

await actor.shutdown.call()

But it would still give error:

AssertionError("Called shutdown on a replica with no proc_mesh.")

Any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see. Ok in that case, I think what we should do is not do actor.shutdown() for now, and just rely on eg

await RLTrainer.stop(trainer)

for now. Maybe what we can do next is have the provisioner keep track of all of the proc meshes, and do a global shutdown()? Including all the services etc. we can discuss more, just want to unblock this PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Done!

lambda self: self._class.shutdown(self), actor
)

return actor

async def shutdown(self):
"""Stop this actor safely without going through endpoint system."""
if getattr(self, "_proc_mesh", None) is None:
raise AssertionError("Called shutdown on a replica with no proc_mesh.")
await stop_proc_mesh(actor._proc_mesh)
await stop_proc_mesh(self._proc_mesh)
4 changes: 2 additions & 2 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def _get_provisioner():

async def get_proc_mesh(config: ProcessConfig) -> ProcMesh:
return await _get_provisioner().get_proc_mesh(
num_procs=config.num_procs,
num_procs=config.procs,
with_gpus=config.with_gpus,
num_hosts=config.num_hosts,
num_hosts=config.hosts,
)


Expand Down
9 changes: 4 additions & 5 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Optional

Expand Down Expand Up @@ -157,11 +157,10 @@ async def initialize(self):
try:
# Deploy the actor and its underlying resources
logger.debug(f"Launching actor for replica {self.idx}")
self.actor = await self.actor_def.launch(
process_config=self.proc_config,
**self.actor_kwargs,
)

self.actor = await self.actor_def.options(
**asdict(self.proc_config)
).as_actor(**self.actor_kwargs)
# Transition to healthy state and start processing
self.state = ReplicaState.HEALTHY
self.start_processing()
Expand Down
1 change: 1 addition & 0 deletions src/forge/data_models/episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Sequence

import torch

from forge.data_models.scored_completion import ScoredCompletion


Expand Down
1 change: 0 additions & 1 deletion src/forge/data_models/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Role(Enum):
Expand Down
10 changes: 5 additions & 5 deletions src/forge/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class State:
class ProcessConfig:
"""A proc_mesh config for the torchx scheduler."""

num_procs: int = 1
procs: int = 1
with_gpus: bool = False
num_hosts: int | None = None
hosts: int | None = None


@dataclass
Expand Down Expand Up @@ -121,12 +121,12 @@ class ServiceConfig:

def to_process_config(self) -> ProcessConfig:
"""Extract ProcessConfig from this ServiceConfig.
Maps procs to num_procs for ProcessConfig.
Maps procs to procs for ProcessConfig.
"""
return ProcessConfig(
num_procs=self.procs,
procs=self.procs,
with_gpus=self.with_gpus,
num_hosts=self.hosts,
hosts=self.hosts,
)


Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/test_policy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import asdict
from typing import Callable

import pytest
Expand All @@ -17,7 +18,6 @@
from forge.actors.trainer import RLTrainer
from forge.controller.service import ServiceConfig
from forge.data.sharding import VLLMSharding

from transformers import AutoModelForCausalLM

requires_cuda = pytest.mark.skipif(
Expand Down Expand Up @@ -262,7 +262,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
policy_config, service_config = get_configs(
worker_size=worker_size, tp_size=worker_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
await policy.update_weights.call()
Expand Down Expand Up @@ -302,7 +302,7 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
policy_config, service_config = get_configs(
worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
)
policy = await Policy.options(service_config=service_config).as_service(
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
await policy.update_weights.call()
Expand Down
8 changes: 3 additions & 5 deletions tests/unit_tests/test_provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import pytest
from forge.controller.provisioner import GpuManager, Provisioner
from forge.types import ProcessConfig


class TestGpuManagerCudaVisibleDevices:
Expand Down Expand Up @@ -158,11 +157,10 @@ async def test_get_proc_mesh_respects_cuda_visible_devices(self):

# Note - this can run even on CPU because with_gpus just sets environment
# variables.
config = ProcessConfig(num_procs=2, with_gpus=True, num_hosts=None)
_ = await provisioner.get_proc_mesh(
num_procs=config.num_procs,
with_gpus=config.with_gpus,
num_hosts=config.num_hosts,
num_procs=2,
with_gpus=True,
num_hosts=None,
)
# Verify GPUs were allocated from available set
remaining_available = local_gpu_manager.get_available_gpus()
Expand Down
Loading
Loading