Skip to content
42 changes: 23 additions & 19 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@

import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.controller.service import (
Service,
ServiceConfig,
shutdown_service,
spawn_service,
)
from forge.data.rewards import MathReward, ThinkingReward
from forge.services.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.services.reference_service import compute_sequence_logprobs, TitanRefModel
from forge.services.replay_buffer import ReplayBuffer
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torchtitan.config.job_config import Model as TitanJobModelConfig
Expand Down Expand Up @@ -49,7 +53,7 @@ def add_group(self, group: Group):
self.groups.append(group)


class Trainer(ForgeActor):
class Trainer(Service):
"""GRPO Trainer implementation for policy optimization."""

def __init__(
Expand Down Expand Up @@ -160,7 +164,7 @@ async def train_step(self, batch: list[Episode]):
return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
async def update_weights(self, policy_actor):
async def update_weights(self, policy_service):
"""Update policy model weights with trainer's current weights."""
# Time how long it takes to update weights
start_time = time.time()
Expand All @@ -176,8 +180,8 @@ async def update_weights(self, policy_actor):
for key, tensor in model_state_dict.items():
cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor

# Update the policy actor's model weights
await policy_actor.update_model_weights.choose(cpu_state_dict)
# Update the policy services's model weights
await policy_service.update_model_weights.choose(cpu_state_dict)

# Set model back to training mode
self.model.train()
Expand All @@ -187,8 +191,8 @@ async def update_weights(self, policy_actor):
self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")


class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""
class RewardService(Service):
"""Reward Service that uses a list of scoring functions."""

def __init__(self, reward_functions: list[Callable]):
super().__init__()
Expand All @@ -203,7 +207,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
return total_reward


class ComputeAdvantages(ForgeActor):
class ComputeAdvantages(Service):
"""Compute advantages for GRPO using reward signals."""

def __init__(self, gamma: float = 0.99, lambda_: float = 0.95):
Expand Down Expand Up @@ -244,8 +248,8 @@ async def __call__(self, groups: list[Group]) -> list[float]:
return advantages


class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""
class DatasetService(Service):
"""Service wrapper for HuggingFace dataset to provide async interface."""

def __init__(
self, path: str, config_name: str, split: str, streaming: bool, **kwargs
Expand Down Expand Up @@ -292,11 +296,11 @@ async def main():
replay_buffer,
compute_advantages,
ref_model,
reward_actor,
reward_service,
) = await asyncio.gather(
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
DatasetActor,
DatasetService,
path="openai/gsm8k",
config_name="main",
split="train",
Expand Down Expand Up @@ -338,7 +342,7 @@ async def main():
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
RewardActor,
RewardService,
reward_functions=[MathReward(), ThinkingReward()],
),
)
Expand Down Expand Up @@ -369,7 +373,7 @@ async def continuous_rollouts():
ref_logprobs = await ref_model.forward.choose(
request=request_tokens, response=response_tokens
)
reward = await reward_actor.evaluate_response.choose(
reward = await reward_service.evaluate_response.choose(
prompt=prompt, response=action.text, target=target
)
episode.add_group(
Expand Down Expand Up @@ -432,7 +436,7 @@ async def continuous_training():
shutdown_service(dataloader),
shutdown_service(compute_advantages),
shutdown_service(ref_model),
shutdown_service(reward_actor),
shutdown_service(reward_service),
)


Expand Down
4 changes: 2 additions & 2 deletions apps/rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import logging
import sys

from forge.actors import ReplayBuffer, RLTrainer

from forge.cli.config import parse
from forge.controller import spawn_actors

from forge.services import ReplayBuffer, RLTrainer
from omegaconf import DictConfig

logger = logging.getLogger(__name__)
Expand Down
6 changes: 3 additions & 3 deletions apps/sft_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.cli.config import parse
from forge.controller import ForgeActor, spawn_actors
from forge.controller import Service, spawn_actors
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
Expand Down Expand Up @@ -53,7 +53,7 @@
logger.setLevel(logging.INFO)


class ForgeSFTRecipe(ForgeActor, ForgeEngine):
class ForgeSFTRecipe(Service, ForgeEngine):
job_config: ForgeJobConfig
train_spec: forge_train_spec.ForgeTrainSpec
parallel_dims: ParallelDims
Expand Down Expand Up @@ -90,7 +90,7 @@ def _init_dist(self):
torchrun normally hands this, but we need to do it ourselves
in monarch for now.

We should consider putting this into ForgeActor, but having this
We should consider putting this into Service, but having this
be explicit for now.

"""
Expand Down
6 changes: 3 additions & 3 deletions apps/toy_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from functools import partial

import torch
from forge.actors.collector import Collector

from forge.actors.replay_buffer import ReplayBuffer
from forge.interfaces import Environment, Policy
from forge.services.collector import Collector

from forge.services.replay_buffer import ReplayBuffer
from forge.types import Action, Observation, State
from monarch.actor import endpoint, proc_mesh

Expand Down
3 changes: 2 additions & 1 deletion apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import asyncio
from argparse import Namespace

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service

from forge.services.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from vllm.outputs import RequestOutput


Expand Down
6 changes: 3 additions & 3 deletions src/forge/controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .actor import ForgeActor
from .proc_mesh import get_proc_mesh, stop_proc_mesh
from .service import Service


# TODO - remove this once everything has moved to
# service
async def spawn_actors(
name: str,
actor_cls: ForgeActor,
actor_cls: Service,
cfg,
processes,
set_address: bool = False,
Expand All @@ -28,5 +28,5 @@ async def spawn_actors(
"spawn_actors",
"stop_proc_mesh",
"get_proc_mesh",
"ForgeActor",
"Service",
]
40 changes: 20 additions & 20 deletions src/forge/controller/actor.py → src/forge/controller/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
logger.setLevel(logging.DEBUG)


class ForgeActor(Actor):
class Service(Actor):
def __init__(self, *args, **kwargs):
if not hasattr(self, "_rank"):
self._rank = current_rank().rank
Expand All @@ -43,11 +43,11 @@ def __init__(self, *args, **kwargs):

@endpoint
async def setup(self):
"""Sets up the actor.
"""Sets up the service.

We assume a specific setup function for all actors. The
best practice for actor deployment is to:
1. Pass all data to the actor via the constructor.
We assume a specific setup function for all services. The
best practice for service deployment is to:
1. Pass all data to the service via the constructor.
2. Call setup() to for heavy weight initializations.

This is to ensure that any failures during initialization
Expand All @@ -57,35 +57,35 @@ async def setup(self):
pass

@classmethod
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
"""Provisions and deploys a new actor.
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "Service":
"""Provisions and deploys a new service.

This method is used by `Service` to provision a new replica.

We implement it this way because special actors like inference servers
may be composed of multiple actors spawned across multiple processes.
This allows you to specify how your actor gets launched together.
We implement it this way because special services like inference servers
may be composed of multiple services spawned across multiple processes.
This allows you to specify how your service gets launched together.

This implementation is basic, assuming that we're spawning
a homogeneous set of actors on a single proc mesh.
a homogeneous set of services on a single proc mesh.

"""
proc_mesh = await get_proc_mesh(process_config=process_config)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
actor = await proc_mesh.spawn(actor_name, cls, **kwargs)
actor._proc_mesh = proc_mesh
service_name = kwargs.pop("name", cls.__name__)
service = await proc_mesh.spawn(service_name, cls, **kwargs)
service._proc_mesh = proc_mesh

await actor.setup.call()
return actor
await service.setup.call()
return service

@classmethod
async def shutdown(cls, actor: "ForgeActor"):
"""Shuts down an actor.
async def shutdown(cls, service: "Service"):
"""Shuts down an service.

This method is used by `Service` to teardown a replica.
"""
if actor._proc_mesh is None:
if service._proc_mesh is None:
raise AssertionError("Called shutdown on a replica with no proc_mesh.")
await stop_proc_mesh(actor._proc_mesh)
await stop_proc_mesh(service._proc_mesh)
6 changes: 3 additions & 3 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from monarch.actor import ActorError

from forge.controller import ForgeActor
from forge.controller import Service
from forge.types import ProcessConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,11 +102,11 @@ class Replica:

# Configuration for the underlying ProcMesh (scheduler, hosts, GPUs)
proc_config: ProcessConfig
actor_def: type[ForgeActor]
actor_def: type[Service]
actor_kwargs: dict

# The Actor that this replica is running
actor: Optional[ForgeActor] = None
actor: Optional[Service] = None

# Async queue for incoming requests
request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue)
Expand Down
19 changes: 10 additions & 9 deletions src/forge/controller/service/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import logging
from typing import Type

from forge.controller import Service

from monarch.actor import proc_mesh

from forge.controller import ForgeActor
from forge.controller.service import Service, ServiceActor, ServiceConfig

from forge.controller.service.interface import ServiceInterface, ServiceInterfaceV2
Expand All @@ -20,7 +21,7 @@


async def spawn_service(
service_cfg: ServiceConfig, actor_def: Type[ForgeActor], **actor_kwargs
service_cfg: ServiceConfig, actor_def: Type[Service], **actor_kwargs
) -> ServiceInterface:
"""Spawns a service based on the actor class.

Expand All @@ -32,10 +33,10 @@ async def spawn_service(
Returns:
A ServiceInterface that provides access to the Service Actor
"""
# Assert that actor_def is a subclass of ForgeActor
if not issubclass(actor_def, ForgeActor):
# Assert that actor_def is a subclass of Service
if not issubclass(actor_def, Service):
raise TypeError(
f"actor_def must be a subclass of ForgeActor, got {type(actor_def).__name__}"
f"actor_def must be a subclass of Service, got {type(actor_def).__name__}"
)

# Create a single-node proc_mesh and actor_mesh for the Service Actor
Expand All @@ -56,7 +57,7 @@ async def shutdown_service(service: ServiceInterface) -> None:


async def spawn_service_v2(
service_cfg: ServiceConfig, actor_def: Type[ForgeActor], **actor_kwargs
service_cfg: ServiceConfig, actor_def: Type[Service], **actor_kwargs
) -> ServiceInterfaceV2:
"""Spawns a service based on the actor class.

Expand All @@ -68,10 +69,10 @@ async def spawn_service_v2(
Returns:
A ServiceInterface that provides access to the Service Actor
"""
# Assert that actor_def is a subclass of ForgeActor
if not issubclass(actor_def, ForgeActor):
# Assert that actor_def is a subclass of Service
if not issubclass(actor_def, Service):
raise TypeError(
f"actor_def must be a subclass of ForgeActor, got {type(actor_def).__name__}"
f"actor_def must be a subclass of Service, got {type(actor_def).__name__}"
)

# Create a single-node proc_mesh and actor_mesh for the Service Actor
Expand Down
4 changes: 2 additions & 2 deletions src/forge/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from monarch.actor import endpoint

from forge.controller import ForgeActor
from forge.controller import Service

from forge.types import Action, Message, Observation, Scalar, State

Expand Down Expand Up @@ -74,7 +74,7 @@ def _apply_transform(self, observation: Observation) -> Observation:
return observation


class Policy(ForgeActor, ABC):
class Policy(Service, ABC):
"""Abstract interface for policies."""

@endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __getattr__(name):

return ReplayBuffer
elif name == "TitanRefModel":
from .reference_actor import TitanRefModel
from .reference_service import TitanRefModel

return TitanRefModel
else:
Expand Down
Loading
Loading