Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/forge/actors/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from dataclasses import dataclass
from typing import Any, Callable

from monarch.actor import endpoint

from forge.controller import ForgeActor

from monarch.actor import endpoint

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand All @@ -38,6 +38,10 @@ async def setup(self) -> None:
random.seed(self.seed)
self.sampler = random.sample

@endpoint
def get_dp_size(self) -> int:
return self.dp_size

@endpoint
async def add(self, episode: "Episode") -> None:
self.buffer.append(episode)
Expand Down
11 changes: 5 additions & 6 deletions tests/unit_tests/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
class TestReplayBuffer:
@pytest_asyncio.fixture
async def replay_buffer(self) -> ReplayBuffer:
mesh = await proc_mesh(gpus=1)
replay_buffer = await mesh.spawn(
"replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1
replay_buffer = await ReplayBuffer.options(procs=1).as_actor(
batch_size=2, max_policy_age=1
)
await replay_buffer.setup.call()
return replay_buffer
Expand Down Expand Up @@ -112,12 +111,12 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
@pytest.mark.asyncio
async def test_sample_dp_size(self) -> None:
"""Test that len(samples) == dp_size when sampling."""
mesh = await proc_mesh(gpus=1)
# Create replay buffer with dp_size=3
replay_buffer = await mesh.spawn(
"replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3
replay_buffer = await ReplayBuffer.options(procs=1).as_actor(
batch_size=2, max_policy_age=1, dp_size=3
)
await replay_buffer.setup.call()
assert replay_buffer.get_dp_size.call_one().get() == 3

# Add enough trajectories to sample
for i in range(10):
Expand Down
107 changes: 107 additions & 0 deletions tests/unit_tests/test_tmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# test_dataclass_service.py

from dataclasses import dataclass, field

import pytest
from forge.controller import ForgeActor

from forge.interfaces import Policy as PolicyInterface
from monarch.actor import endpoint


@dataclass
class DataCounter(ForgeActor):
"""ForgeActor implemented as a dataclass."""

v: int = field(default=1)

@endpoint
async def value(self) -> int:
return self.v


@dataclass
class SimplePolicy(PolicyInterface):
"""Minimal concrete policy dataclass."""

version: int = field(default=1)
enabled: bool = field(default=True)

@endpoint
async def generate(self, request):
# Just return the request directly as a "dummy action"
return request

@endpoint
async def update_weights(self, policy_version: int):
# Store the new version
self.version = policy_version
return self.version

@endpoint
async def get_enabled(self) -> bool:
"""Get the enabled status."""
return self.enabled

@endpoint
async def get_version(self) -> int:
"""Get the current version."""
return self.version


@pytest.mark.asyncio
@pytest.mark.timeout(10)
async def test_dataclass_as_service_initialization():
"""Test that dataclass actor can be initialized via as_service()."""
service = await DataCounter.as_service(42)
try:
result = await service.value.choose()
assert result == 42
finally:
await service.shutdown()


@pytest.mark.asyncio
@pytest.mark.timeout(10)
async def test_simple_policy_as_service_and_endpoints():
"""Test that SimplePolicy can be initialized and its endpoints work."""

# Start service with initial version=1 and default enabled=True
service = await SimplePolicy.as_service(version=1)
try:
# Check that the default enabled field is True
enabled_status = await service.get_enabled.choose()
assert enabled_status is True

# Initial version should be 1
initial_version = await service.get_version.choose()
assert initial_version == 1

# Update weights to version 2
v = await service.update_weights.choose(2)
assert v == 2

# Verify version was updated
updated_version = await service.get_version.choose()
assert updated_version == 2

# Call generate — should just echo back the input
result = await service.generate.choose("obs")
assert result == "obs"

finally:
await service.shutdown()

# Test with explicit enabled=False
service2 = await SimplePolicy.as_service(version=3, enabled=False)
try:
# Check that the enabled field is False
enabled_status2 = await service2.get_enabled.choose()
assert enabled_status2 is False

# Verify version is set correctly
version2 = await service2.get_version.choose()
assert version2 == 3

finally:
await service2.shutdown()
Loading