diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index fd60ce35c..98e89bab6 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -9,10 +9,10 @@ from dataclasses import dataclass from typing import Any, Callable -from monarch.actor import endpoint - from forge.controller import ForgeActor +from monarch.actor import endpoint + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -38,6 +38,10 @@ async def setup(self) -> None: random.seed(self.seed) self.sampler = random.sample + @endpoint + def get_dp_size(self) -> int: + return self.dp_size + @endpoint async def add(self, episode: "Episode") -> None: self.buffer.append(episode) diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index 4463c3f2c..4e4084ccc 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -17,9 +17,8 @@ class TestReplayBuffer: @pytest_asyncio.fixture async def replay_buffer(self) -> ReplayBuffer: - mesh = await proc_mesh(gpus=1) - replay_buffer = await mesh.spawn( - "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1 + replay_buffer = await ReplayBuffer.options(procs=1).as_actor( + batch_size=2, max_policy_age=1 ) await replay_buffer.setup.call() return replay_buffer @@ -112,12 +111,12 @@ async def test_sample_with_evictions(self, replay_buffer) -> None: @pytest.mark.asyncio async def test_sample_dp_size(self) -> None: """Test that len(samples) == dp_size when sampling.""" - mesh = await proc_mesh(gpus=1) # Create replay buffer with dp_size=3 - replay_buffer = await mesh.spawn( - "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3 + replay_buffer = await ReplayBuffer.options(procs=1).as_actor( + batch_size=2, max_policy_age=1, dp_size=3 ) await replay_buffer.setup.call() + assert replay_buffer.get_dp_size.call_one().get() == 3 # Add enough trajectories to sample for i in range(10): diff --git a/tests/unit_tests/test_tmp.py b/tests/unit_tests/test_tmp.py new file mode 100644 index 000000000..f18893768 --- /dev/null +++ b/tests/unit_tests/test_tmp.py @@ -0,0 +1,107 @@ +# test_dataclass_service.py + +from dataclasses import dataclass, field + +import pytest +from forge.controller import ForgeActor + +from forge.interfaces import Policy as PolicyInterface +from monarch.actor import endpoint + + +@dataclass +class DataCounter(ForgeActor): + """ForgeActor implemented as a dataclass.""" + + v: int = field(default=1) + + @endpoint + async def value(self) -> int: + return self.v + + +@dataclass +class SimplePolicy(PolicyInterface): + """Minimal concrete policy dataclass.""" + + version: int = field(default=1) + enabled: bool = field(default=True) + + @endpoint + async def generate(self, request): + # Just return the request directly as a "dummy action" + return request + + @endpoint + async def update_weights(self, policy_version: int): + # Store the new version + self.version = policy_version + return self.version + + @endpoint + async def get_enabled(self) -> bool: + """Get the enabled status.""" + return self.enabled + + @endpoint + async def get_version(self) -> int: + """Get the current version.""" + return self.version + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_dataclass_as_service_initialization(): + """Test that dataclass actor can be initialized via as_service().""" + service = await DataCounter.as_service(42) + try: + result = await service.value.choose() + assert result == 42 + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_simple_policy_as_service_and_endpoints(): + """Test that SimplePolicy can be initialized and its endpoints work.""" + + # Start service with initial version=1 and default enabled=True + service = await SimplePolicy.as_service(version=1) + try: + # Check that the default enabled field is True + enabled_status = await service.get_enabled.choose() + assert enabled_status is True + + # Initial version should be 1 + initial_version = await service.get_version.choose() + assert initial_version == 1 + + # Update weights to version 2 + v = await service.update_weights.choose(2) + assert v == 2 + + # Verify version was updated + updated_version = await service.get_version.choose() + assert updated_version == 2 + + # Call generate — should just echo back the input + result = await service.generate.choose("obs") + assert result == "obs" + + finally: + await service.shutdown() + + # Test with explicit enabled=False + service2 = await SimplePolicy.as_service(version=3, enabled=False) + try: + # Check that the enabled field is False + enabled_status2 = await service2.get_enabled.choose() + assert enabled_status2 is False + + # Verify version is set correctly + version2 = await service2.get_version.choose() + assert version2 == 3 + + finally: + await service2.shutdown()