diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index df79c302e..9bdde3ce6 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -7,11 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, Mapping -from monarch.actor import endpoint - -from forge.controller import ForgeActor - -from forge.types import Action, Message, Observation, Scalar, State +from forge.types import Message, Observation, Scalar class Transform(ABC): @@ -37,63 +33,6 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): - """Abstract base class for environments. - - Args: - transform: Optional transform that modifies observations, typically to add rewards. - Can be a Transform instance or a callable for backward compatibility. - """ - - def __init__( - self, - transform: Transform | None = None, - ): - self.transform = transform - - @abstractmethod - def reset(self) -> Observation: - """Reset the environment and return an initial observation.""" - pass - - @abstractmethod - def step(self, action: Any) -> Observation: - """Take a step in the environment and return an observation.""" - pass - - @property - @abstractmethod - def state(self) -> State: - """Get the current state of the environment.""" - pass - - def _apply_transform(self, observation: Observation) -> Observation: - """Apply the transform to an observation if one is provided.""" - if self.transform is not None: - return self.transform(observation) - return observation - - -class Policy(ForgeActor, ABC): - """Abstract interface for policies.""" - - @endpoint - @abstractmethod - async def generate(self, request: Observation) -> Action: - """Generate an action given a state/request.""" - pass - - @endpoint - @abstractmethod - async def update_weights(self, policy_version: int): - """Update the policy weights. - - Args: - policy_version: The version number to update to. 
- """ - pass - - class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. @@ -210,10 +149,3 @@ class Reward(ABC): def __call__(self, observation: Observation) -> float: """Compute a reward for an observation.""" pass - - -# TODO -# class RLLoss(ABC): - -# class SFTLoss(ABC): # inherit from titan loss -# from torchtitan.components.loss import LossFunction diff --git a/src/forge/types.py b/src/forge/types.py index 6a9dcc122..fa77a83de 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -15,15 +15,6 @@ class Message(TypedDict): tools: dict[str, Any] | None -@dataclass -class ForgeEnvInfo: - """Environment info returned with observations.""" - - episode_id: str | None = None - step_count: int = 0 - metadata: dict | None = None - - @dataclass(kw_only=True) class Observation: """Base class for environment observations. @@ -44,50 +35,6 @@ class Observation: metadata: dict[str, Any] = field(default_factory=dict) -@dataclass(kw_only=True) -class Action: - """Base class for environment actions. - - Contract: - - Should contain all information needed to execute a step in the environment - - Should be serializable/deserializable - - Should be immutable (or treated as such) - - Args: - metadata: Additional data that may be useful for logging, debugging, or transforms - """ - - metadata: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class Trajectory: - """A trajectory containing a sequence of states, actions, etc.""" - - policy_version: int - states: list[Observation] = field(default_factory=list) - actions: list[Action] = field(default_factory=list) - - def __post_init__(self): - assert self.policy_version >= 0 - - -@dataclass(kw_only=True) -class State: - """Base class for environment state. 
- - Contract: - - Should contain all information needed to restore the environment - - Should be serializable/deserializable - - May contain information not exposed in observations - - Args: - metadata: Additional state information that may be useful for debugging or analysis - """ - - metadata: dict[str, Any] = field(default_factory=dict) - - class Launcher(Enum): MAST = "mast" SLURM = "slurm" diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index e6c6876c3..10053b78f 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -6,10 +6,25 @@ """Test for data/replay_buffer.py""" +from dataclasses import dataclass + import pytest import pytest_asyncio from forge.actors.replay_buffer import ReplayBuffer -from forge.types import Trajectory + + +@dataclass +class TestEpisode: + """ + Dummy Episode containing just a policy version + + ReplayBuffer expects any construct (typically an Episode) that contains a + `policy_version`. + + TODO: Replace with a unified interface in the future. 
+ """ + + policy_version: int class TestReplayBuffer: @@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer: @pytest.mark.asyncio async def test_add(self, replay_buffer: ReplayBuffer) -> None: - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) assert replay_buffer._numel.call_one().get() == 1 - assert replay_buffer._getitem.call_one(0).get() == trajectory + assert replay_buffer._getitem.call_one(0).get() == episode replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_add_multiple(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 - assert replay_buffer._getitem.call_one(0).get() == trajectory_0 - assert replay_buffer._getitem.call_one(1).get() == trajectory_1 + assert replay_buffer._getitem.call_one(0).get() == episode_0 + assert replay_buffer._getitem.call_one(1).get() == episode_1 replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_state_dict_save_load(self, replay_buffer) -> None: - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) state_dict = replay_buffer.state_dict.call_one().get() replay_buffer.clear.call_one().get() assert replay_buffer._numel.call_one().get() == 0 @@ -53,10 +68,10 @@ async def test_state_dict_save_load(self, replay_buffer) -> None: @pytest.mark.asyncio async def test_evict(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - 
trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 await replay_buffer.evict.call_one(curr_policy_version=2) assert replay_buffer._numel.call_one().get() == 1 @@ -64,10 +79,10 @@ async def test_evict(self, replay_buffer) -> None: @pytest.mark.asyncio async def test_sample(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 # Test a simple sampling @@ -77,19 +92,19 @@ async def test_sample(self, replay_buffer) -> None: assert replay_buffer._numel.call_one().get() == 2 # Test sampling (not enough samples in buffer, returns None) - await replay_buffer.add.call_one(trajectory_0) + await replay_buffer.add.call_one(episode_0) samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is None replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_sample_with_evictions(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - trajectory_2 = Trajectory(policy_version=2) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) - await replay_buffer.add.call_one(trajectory_2) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + episode_2 = TestEpisode(policy_version=2) + await 
replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) + await replay_buffer.add.call_one(episode_2) assert replay_buffer._numel.call_one().get() == 3 samples = await replay_buffer.sample.call_one( curr_policy_version=2, @@ -112,8 +127,8 @@ async def test_sample_dp_size(self) -> None: # Add enough trajectories to sample for i in range(10): - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) # Sample and verify len(samples) == dp_size samples = await replay_buffer.sample.call_one(curr_policy_version=0)