diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index b53103576..bd61abe82 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -18,8 +18,8 @@ model: processes: scheduler: local # local | mast (not supported yet) - num_hosts: 1 - num_procs: 8 + hosts: 1 + procs: 8 with_gpus: true optimizer: diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index de21b3e9e..866524ac2 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -167,8 +167,8 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] # Once we can create multiple proc meshes on a host mesh, we can ensure # host colocation policy_proc_config = copy(process_config) - policy_proc_config.num_procs = 1 - policy_proc_config.num_hosts = None + policy_proc_config.procs = 1 + policy_proc_config.hosts = None policy_proc_config.with_gpus = False policy_proc = await get_proc_mesh(process_config=policy_proc_config) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 295b0420b..3cc1e6a48 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -8,7 +8,7 @@ import math import sys -from typing import Type, TypeVar +from typing import Any, Type, TypeVar from monarch.actor import Actor, current_rank, current_size, endpoint @@ -22,6 +22,12 @@ class ForgeActor(Actor): + procs: int = 1 + hosts: int | None = None + with_gpus: bool = False + num_replicas: int = 1 + _extra_config: dict[str, Any] = {} + def __init__(self, *args, **kwargs): if not hasattr(self, "_rank"): self._rank = current_rank().rank @@ -48,71 +54,74 @@ def __init__(self, *args, **kwargs): def options( cls: Type[T], *, - service_config: ServiceConfig | None = None, - num_replicas: int | None = None, - procs: int | None = None, - **service_kwargs, + procs: int = 1, + hosts: int | None = None, + with_gpus: bool = False, + num_replicas: int = 1, + **kwargs, ) -> Type[T]: """ - Returns a subclass of this ForgeActor with a bound ServiceConfig. - The returned subclass can later be launched via `.as_service()`. - - Usage (choose ONE of the following forms): - # Option A: construct ServiceConfig implicitly - service = await MyForgeActor.options( - num_replicas=1, - procs=2, - ).as_service(...) - await service.shutdown() - - # Option B: provide an explicit ServiceConfig - cfg = ServiceConfig(num_replicas=1, procs=2, ..) - service = await MyForgeActor.options(service_config=cfg).as_service(...) - await service.shutdown() - - # Option C: skip options, use the default service config with num_replicas=1, procs=1 - service = await MyForgeActor.as_service(...) - await service.shutdown() + Returns a version of ForgeActor with configured resource attributes. + + This method allows you to pre-configure an actor class before spawning it with + `.as_actor()` or `.as_service()`. Each call creates a separate subclass, so + multiple different configurations can coexist without interfering with each other. + + ---- Usage Examples ---- + + # Pre-configure a service with multiple replicas + service = await MyForgeActor.options(num_replicas=2, procs=2).as_service(...) + await service.shutdown() + + # Default usage without calling options + service = await MyForgeActor.as_service(...) + await service.shutdown() + + # Pre-configure a single actor + actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...) + await actor.shutdown() + + # Default usage without calling options + actor = await MyForgeActor.as_actor(...) + await actor.shutdown() """ - if service_config is not None: - cfg = service_config - else: - if num_replicas is None or procs is None: - raise ValueError( - "Must provide either `service_config` or (num_replicas + procs)." - ) - cfg = ServiceConfig( - num_replicas=num_replicas, - procs=procs, - **service_kwargs, - ) - - return type( - f"{cls.__name__}Service", - (cls,), - {"_service_config": cfg}, - ) + attrs = { + "procs": procs, + "hosts": hosts, + "with_gpus": with_gpus, + "num_replicas": num_replicas, + "_extra_config": kwargs, + } + + return type(cls.__name__, (cls,), attrs) @classmethod - async def as_service(cls: Type[T], **actor_kwargs) -> "ServiceInterface": + async def as_service( + cls: Type[T], *actor_args, **actor_kwargs + ) -> "ServiceInterface": """ - Convenience method to spawn this actor as a Service using default configuration. - If `.options()` was called, it will use the bound ServiceConfig; - otherwise defaults to 1 replica, 1 proc. + Spawns this actor as a Service using the configuration stored in `.options()`, + or defaults if `.options()` was not called. + + The configuration values stored in the subclass returned by `.options()` (like + `procs` and `num_replicas`) are used to construct a ServiceConfig instance. + If no configuration was stored, defaults to a single replica with one process. """ # Lazy import to avoid top-level dependency issues from forge.controller.service import Service, ServiceInterface - # Use _service_config if already set by options(), else default - cfg = getattr(cls, "_service_config", None) - if cfg is None: - cfg = ServiceConfig(num_replicas=1, procs=1) - # dynamically create a configured subclass for consistency - cls = type(f"{cls.__name__}Service", (cls,), {"_service_config": cfg}) + cfg_kwargs = { + "procs": cls.procs, + "hosts": cls.hosts, + "with_gpus": cls.with_gpus, + "num_replicas": cls.num_replicas, + **cls._extra_config, # all extra fields + } + cfg = ServiceConfig(**cfg_kwargs) logger.info("Spawning Service Actor for %s", cls.__name__) - service = Service(cfg, cls, actor_kwargs) + service = Service(cfg, cls, actor_args, actor_kwargs) await service.__initialize__() return ServiceInterface(service, cls) @@ -154,7 +163,7 @@ async def set_env(self, addr: str, port: str): os.environ["MASTER_PORT"] = port @classmethod - async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor": + async def launch(cls, *args, **kwargs) -> "ForgeActor": """Provisions and deploys a new actor. This method is used by `Service` to provision a new replica. @@ -167,11 +176,17 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor a homogeneous set of actors on a single proc mesh. """ - proc_mesh = await get_proc_mesh(process_config=process_config) + # Build process config + cfg = ProcessConfig( + procs=cls.procs, + hosts=cls.hosts, + with_gpus=cls.with_gpus, + ) + + proc_mesh = await get_proc_mesh(process_config=cfg) - # TODO - expand support so name can stick within kwargs actor_name = kwargs.pop("name", cls.__name__) - actor = await proc_mesh.spawn(actor_name, cls, **kwargs) + actor = await proc_mesh.spawn(actor_name, cls, *args, **kwargs) actor._proc_mesh = proc_mesh if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"): @@ -180,10 +195,22 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor await actor.setup.call() return actor + @classmethod + async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: + """ + Spawns a single actor using the configuration stored in `.options()`, or defaults. + + The configuration values stored in the subclass returned by `.options()` (like + `procs`) are used to construct a ProcessConfig instance. + If no configuration was stored, defaults to a single process with no GPU. + """ + logger.info("Spawning single actor %s", cls.__name__) + actor = await cls.launch(*args, **actor_kwargs) + return actor + @classmethod async def shutdown(cls, actor: "ForgeActor"): """Shuts down an actor. - This method is used by `Service` to teardown a replica. """ if actor._proc_mesh is None: diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5fd42cf40..c85013010 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -244,9 +244,9 @@ def _get_provisioner(): async def get_proc_mesh(config: ProcessConfig) -> ProcMesh: return await _get_provisioner().get_proc_mesh( - num_procs=config.num_procs, + num_procs=config.procs, with_gpus=config.with_gpus, - num_hosts=config.num_hosts, + num_hosts=config.hosts, ) diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index b84e5eec7..6f9f3de72 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -9,7 +9,7 @@ import logging import time from collections import deque -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Optional @@ -103,6 +103,7 @@ class Replica: # Configuration for the underlying ProcMesh (scheduler, hosts, GPUs) proc_config: ProcessConfig actor_def: type[ForgeActor] + actor_args: tuple actor_kwargs: dict # The Actor that this replica is running @@ -157,11 +158,10 @@ async def initialize(self): try: # Deploy the actor and its underlying resources logger.debug(f"Launching actor for replica {self.idx}") - self.actor = await self.actor_def.launch( - process_config=self.proc_config, - **self.actor_kwargs, - ) + self.actor = await self.actor_def.options( + **asdict(self.proc_config) + ).as_actor(*self.actor_args, **self.actor_kwargs) # Transition to healthy state and start processing self.state = ReplicaState.HEALTHY self.start_processing() diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 2b8d8ab9c..66c9c234e 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -82,11 +82,13 @@ def __init__( self, cfg: ServiceConfig, actor_def, + actor_args: tuple, actor_kwargs: dict, ): self._cfg = cfg self._replicas = [] self._actor_def = actor_def + self._actor_args = actor_args self._actor_kwargs = actor_kwargs self._active_sessions = [] @@ -119,6 +121,7 @@ async def __initialize__(self): max_concurrent_requests=self._cfg.replica_max_concurrent_requests, return_first_rank_result=self._cfg.return_first_rank_result, actor_def=self._actor_def, + actor_args=self._actor_args, actor_kwargs=self._actor_kwargs, ) replicas.append(replica) diff --git a/src/forge/data_models/episode.py b/src/forge/data_models/episode.py index 5df2352ab..835373d18 100644 --- a/src/forge/data_models/episode.py +++ b/src/forge/data_models/episode.py @@ -8,6 +8,7 @@ from typing import Optional, Sequence import torch + from forge.data_models.scored_completion import ScoredCompletion diff --git a/src/forge/data_models/prompt.py b/src/forge/data_models/prompt.py index 55f538c0e..c7f79c04e 100644 --- a/src/forge/data_models/prompt.py +++ b/src/forge/data_models/prompt.py @@ -7,7 +7,6 @@ from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Any class Role(Enum): diff --git a/src/forge/types.py b/src/forge/types.py index 9de8b8708..cc41d2185 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -91,9 +91,9 @@ class State: class ProcessConfig: """A proc_mesh config for the torchx scheduler.""" - num_procs: int = 1 + procs: int = 1 with_gpus: bool = False - num_hosts: int | None = None + hosts: int | None = None @dataclass @@ -121,12 +121,12 @@ class ServiceConfig: def to_process_config(self) -> ProcessConfig: """Extract ProcessConfig from this ServiceConfig. - Maps procs to num_procs for ProcessConfig. + Maps procs to procs for ProcessConfig. """ return ProcessConfig( - num_procs=self.procs, + procs=self.procs, with_gpus=self.with_gpus, - num_hosts=self.hosts, + hosts=self.hosts, ) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 5fdce0b6a..d265df206 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging +from dataclasses import asdict from typing import Callable import pytest @@ -17,7 +18,6 @@ from forge.actors.trainer import RLTrainer from forge.controller.service import ServiceConfig from forge.data.sharding import VLLMSharding - from transformers import AutoModelForCausalLM requires_cuda = pytest.mark.skipif( @@ -262,7 +262,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg): policy_config, service_config = get_configs( worker_size=worker_size, tp_size=worker_size, model_name=self.model ) - policy = await Policy.options(service_config=service_config).as_service( + policy = await Policy.options(**asdict(service_config)).as_service( **policy_config ) await policy.update_weights.call() @@ -302,7 +302,7 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp): policy_config, service_config = get_configs( worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) - policy = await Policy.options(service_config=service_config).as_service( + policy = await Policy.options(**asdict(service_config)).as_service( **policy_config ) await policy.update_weights.call() diff --git a/tests/unit_tests/test_provisioner.py b/tests/unit_tests/test_provisioner.py index 62a68e5a6..888e75e3c 100644 --- a/tests/unit_tests/test_provisioner.py +++ b/tests/unit_tests/test_provisioner.py @@ -11,7 +11,6 @@ import pytest from forge.controller.provisioner import GpuManager, Provisioner -from forge.types import ProcessConfig class TestGpuManagerCudaVisibleDevices: @@ -158,11 +157,10 @@ async def test_get_proc_mesh_respects_cuda_visible_devices(self): # Note - this can run even on CPU because with_gpus just sets environment # variables. - config = ProcessConfig(num_procs=2, with_gpus=True, num_hosts=None) _ = await provisioner.get_proc_mesh( - num_procs=config.num_procs, - with_gpus=config.with_gpus, - num_hosts=config.num_hosts, + num_procs=2, + with_gpus=True, + num_hosts=None, ) # Verify GPUs were allocated from available set remaining_available = local_gpu_manager.get_available_gpus() diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index c1161db9c..968f1eea2 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -69,6 +69,7 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: idx=idx, proc_config=ProcessConfig(), actor_def=Counter, + actor_args=(), actor_kwargs={}, ) replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY @@ -76,7 +77,60 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: return replica -# Core Functionality Tests +# Actor Tests + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_as_actor_with_args_config(): + """Test spawning a single actor with passing configs through kwargs.""" + actor = await Counter.options(procs=1).as_actor(5) + + try: + assert await actor.value.choose() == 5 + + # Test increment + await actor.incr.choose() + assert await actor.value.choose() == 6 + + finally: + await Counter.shutdown(actor) + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_as_actor_default_usage(): + """Test spawning a single actor directly via .as_actor() using default config.""" + actor = await Counter.as_actor(v=7) + try: + # Check initial value + assert await actor.value.choose() == 7 + + # Test increment + await actor.incr.choose() + assert await actor.value.choose() == 8 + + finally: + await Counter.shutdown(actor) + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_options_applies_config(): + """Test config via options class.""" + actor_cls = Counter.options(procs=1, with_gpus=True, num_replicas=2) + assert actor_cls.procs == 1 + assert actor_cls.with_gpus is True + assert actor_cls.num_replicas == 2 + + actor = await actor_cls.as_actor(v=3) + try: + assert await actor.value.choose() == 3 + finally: + await Counter.shutdown(actor) + + +# Service Config Tests @pytest.mark.timeout(10) @@ -94,25 +148,10 @@ def __init__(self): await InvalidActor.options(procs=1, num_replicas=1).as_service() -@pytest.mark.timeout(20) -@pytest.mark.asyncio -async def test_service_with_explicit_service_config(): - """Case 1: Provide a ServiceConfig directly.""" - cfg = ServiceConfig(procs=2, num_replicas=3) - service = await Counter.options(service_config=cfg).as_service(v=10) - try: - assert service._service._cfg is cfg - assert service._service._cfg.num_replicas == 3 - assert service._service._cfg.procs == 2 - assert await service.value.choose() == 10 - finally: - await service.shutdown() - - @pytest.mark.timeout(20) @pytest.mark.asyncio async def test_service_with_kwargs_config(): - """Case 2: Construct ServiceConfig implicitly from kwargs.""" + """Construct ServiceConfig implicitly from kwargs.""" service = await Counter.options( num_replicas=4, procs=1, @@ -129,19 +168,11 @@ async def test_service_with_kwargs_config(): await service.shutdown() -@pytest.mark.timeout(20) -@pytest.mark.asyncio -async def test_service_options_missing_args_raises(): - """Case 3: Error if neither service_config nor required args are provided.""" - with pytest.raises(ValueError, match="Must provide either"): - await Counter.options().as_service() # no args, should raise before service spawn - - @pytest.mark.timeout(20) @pytest.mark.asyncio async def test_service_default_config(): - """Case 4: Construct with default configuration using as_service directly.""" - service = await Counter.as_service(v=10) + """Construct with default configuration using as_service directly.""" + service = await Counter.as_service(10) try: cfg = service._service._cfg assert cfg.num_replicas == 1 @@ -151,12 +182,46 @@ async def test_service_default_config(): await service.shutdown() +@pytest.mark.asyncio +@pytest.mark.timeout(20) +async def test_multiple_services_isolated_configs(): + """Ensure multiple services from the same actor class have independent configs.""" + + # Create first service with 2 replicas + service1 = await Counter.options(num_replicas=2, procs=1).as_service(v=10) + + # Create second service with 4 replicas + service2 = await Counter.options(num_replicas=4, procs=1).as_service(v=20) + + try: + # Check that the _service_config objects are independent + cfg1 = service1._service._cfg + cfg2 = service2._service._cfg + + assert cfg1.num_replicas == 2 + assert cfg2.num_replicas == 4 + assert cfg1 is not cfg2 # configs should not be the same object + + # Check actor values + val1 = await service1.value.choose() + val2 = await service2.value.choose() + + assert val1 == 10 + assert val2 == 20 + + finally: + await service1.shutdown() + await service2.shutdown() + + +# Core Functionality Tests + + @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_basic_service_operations(): """Test basic service creation, sessions, and endpoint calls.""" - cfg = ServiceConfig(procs=1, num_replicas=1) - service = await Counter.options(service_config=cfg).as_service(v=0) + service = await Counter.options(procs=1, num_replicas=1).as_service(v=0) try: # Test session creation and uniqueness @@ -437,7 +502,7 @@ async def test_metrics_collection(): @pytest.mark.asyncio async def test_session_stickiness(): """Test that sessions stick to the same replica.""" - service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + service = await Counter.options(procs=1, num_replicas=2).as_service(0) try: session = await service.start_session()