Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 1 addition & 69 deletions src/forge/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping

from monarch.actor import endpoint

from forge.controller import ForgeActor

from forge.types import Action, Message, Observation, Scalar, State
from forge.types import Message, Observation, Scalar


class Transform(ABC):
Expand All @@ -37,63 +33,6 @@ def __call__(self, observation: Observation) -> Observation:
pass


class Environment(ABC):
    """Interface that concrete environments implement.

    Subclasses must provide ``reset``, ``step`` and the ``state`` property.

    Args:
        transform: Optional transform that modifies observations, typically to
            add rewards. Can be a Transform instance or a callable for
            backward compatibility. Stored as ``self.transform``.
    """

    def __init__(self, transform: Transform | None = None):
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and hand back an initial observation."""

    @abstractmethod
    def step(self, action: Any) -> Observation:
        """Advance the environment by one step and hand back an observation."""

    @property
    @abstractmethod
    def state(self) -> State:
        """The current state of the environment."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Run ``observation`` through the transform, if one was provided.

        Returns the observation unchanged when ``self.transform`` is None.
        """
        if self.transform is None:
            return observation
        return self.transform(observation)


class Policy(ForgeActor, ABC):
    """Abstract interface that concrete policies implement.

    Both methods are abstract and exposed through the ``endpoint`` decorator.
    """

    @endpoint
    @abstractmethod
    async def generate(self, request: Observation) -> Action:
        """Produce an action for the given state/request."""

    @endpoint
    @abstractmethod
    async def update_weights(self, policy_version: int):
        """Update the policy weights.

        Args:
            policy_version: The version number to update to.
        """


class BaseTokenizer(ABC):
"""
Abstract token encoding model that implements ``encode`` and ``decode`` methods.
Expand Down Expand Up @@ -210,10 +149,3 @@ class Reward(ABC):
def __call__(self, observation: Observation) -> float:
"""Compute a reward for an observation."""
pass


# TODO
# class RLLoss(ABC):

# class SFTLoss(ABC): # inherit from titan loss
# from torchtitan.components.loss import LossFunction
53 changes: 0 additions & 53 deletions src/forge/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@ class Message(TypedDict):
tools: dict[str, Any] | None


@dataclass
class ForgeEnvInfo:
"""Environment info returned with observations."""

episode_id: str | None = None
step_count: int = 0
metadata: dict | None = None


@dataclass(kw_only=True)
class Observation:
"""Base class for environment observations.
Expand All @@ -44,50 +35,6 @@ class Observation:
metadata: dict[str, Any] = field(default_factory=dict)


@dataclass(kw_only=True)
class Action:
"""Base class for environment actions.

Contract:
- Should contain all information needed to execute a step in the environment
- Should be serializable/deserializable
- Should be immutable (or treated as such)

Args:
metadata: Additional data that may be useful for logging, debugging, or transforms
"""

metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class Trajectory:
    """A trajectory containing a sequence of states, actions, etc.

    Args:
        policy_version: Policy version attached to this trajectory; checked to
            be non-negative right after construction.
        states: Observations of the trajectory; a fresh empty list by default.
        actions: Actions of the trajectory; a fresh empty list by default.
    """

    policy_version: int
    states: list[Observation] = field(default_factory=list)
    actions: list[Action] = field(default_factory=list)

    def __post_init__(self):
        # Runs after the dataclass-generated __init__.
        # NOTE(review): an assert is stripped under `python -O`; it is kept as
        # an assert here so callers keep seeing AssertionError.
        assert self.policy_version >= 0


@dataclass(kw_only=True)
class State:
"""Base class for environment state.

Contract:
- Should contain all information needed to restore the environment
- Should be serializable/deserializable
- May contain information not exposed in observations

Args:
metadata: Additional state information that may be useful for debugging or analysis
"""

metadata: dict[str, Any] = field(default_factory=dict)


class Launcher(Enum):
MAST = "mast"
SLURM = "slurm"
Expand Down
73 changes: 44 additions & 29 deletions tests/unit_tests/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,25 @@

"""Test for data/replay_buffer.py"""

from dataclasses import dataclass

import pytest
import pytest_asyncio
from forge.actors.replay_buffer import ReplayBuffer
from forge.types import Trajectory


@dataclass
class TestEpisode:
    """
    Dummy Episode containing just a policy version

    ReplayBuffer expects any construct (typically an Episode) that contains a
    `policy_version`.

    TODO: Replace with a unified interface in the future.
    """

    # The class name starts with "Test", so pytest would try to collect this
    # helper as a test class and warn about its generated __init__. Opt out
    # explicitly. Being unannotated, this is a plain class attribute and not a
    # dataclass field, so the constructor signature is unchanged.
    __test__ = False

    policy_version: int


class TestReplayBuffer:
Expand All @@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer:

@pytest.mark.asyncio
async def test_add(self, replay_buffer: ReplayBuffer) -> None:
trajectory = Trajectory(policy_version=0)
await replay_buffer.add.call_one(trajectory)
episode = TestEpisode(policy_version=0)
await replay_buffer.add.call_one(episode)
assert replay_buffer._numel.call_one().get() == 1
assert replay_buffer._getitem.call_one(0).get() == trajectory
assert replay_buffer._getitem.call_one(0).get() == episode
replay_buffer.clear.call_one().get()

@pytest.mark.asyncio
async def test_add_multiple(self, replay_buffer) -> None:
trajectory_0 = Trajectory(policy_version=0)
trajectory_1 = Trajectory(policy_version=1)
await replay_buffer.add.call_one(trajectory_0)
await replay_buffer.add.call_one(trajectory_1)
episode_0 = TestEpisode(policy_version=0)
episode_1 = TestEpisode(policy_version=1)
await replay_buffer.add.call_one(episode_0)
await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2
assert replay_buffer._getitem.call_one(0).get() == trajectory_0
assert replay_buffer._getitem.call_one(1).get() == trajectory_1
assert replay_buffer._getitem.call_one(0).get() == episode_0
assert replay_buffer._getitem.call_one(1).get() == episode_1
replay_buffer.clear.call_one().get()

@pytest.mark.asyncio
async def test_state_dict_save_load(self, replay_buffer) -> None:
trajectory = Trajectory(policy_version=0)
await replay_buffer.add.call_one(trajectory)
episode = TestEpisode(policy_version=0)
await replay_buffer.add.call_one(episode)
state_dict = replay_buffer.state_dict.call_one().get()
replay_buffer.clear.call_one().get()
assert replay_buffer._numel.call_one().get() == 0
Expand All @@ -53,21 +68,21 @@ async def test_state_dict_save_load(self, replay_buffer) -> None:

@pytest.mark.asyncio
async def test_evict(self, replay_buffer) -> None:
trajectory_0 = Trajectory(policy_version=0)
trajectory_1 = Trajectory(policy_version=1)
await replay_buffer.add.call_one(trajectory_0)
await replay_buffer.add.call_one(trajectory_1)
episode_0 = TestEpisode(policy_version=0)
episode_1 = TestEpisode(policy_version=1)
await replay_buffer.add.call_one(episode_0)
await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2
await replay_buffer.evict.call_one(curr_policy_version=2)
assert replay_buffer._numel.call_one().get() == 1
replay_buffer.clear.call_one().get()

@pytest.mark.asyncio
async def test_sample(self, replay_buffer) -> None:
trajectory_0 = Trajectory(policy_version=0)
trajectory_1 = Trajectory(policy_version=1)
await replay_buffer.add.call_one(trajectory_0)
await replay_buffer.add.call_one(trajectory_1)
episode_0 = TestEpisode(policy_version=0)
episode_1 = TestEpisode(policy_version=1)
await replay_buffer.add.call_one(episode_0)
await replay_buffer.add.call_one(episode_1)
assert replay_buffer._numel.call_one().get() == 2

# Test a simple sampling
Expand All @@ -77,19 +92,19 @@ async def test_sample(self, replay_buffer) -> None:
assert replay_buffer._numel.call_one().get() == 2

# Test sampling (not enough samples in buffer, returns None)
await replay_buffer.add.call_one(trajectory_0)
await replay_buffer.add.call_one(episode_0)
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
assert samples is None
replay_buffer.clear.call_one().get()

@pytest.mark.asyncio
async def test_sample_with_evictions(self, replay_buffer) -> None:
trajectory_0 = Trajectory(policy_version=0)
trajectory_1 = Trajectory(policy_version=1)
trajectory_2 = Trajectory(policy_version=2)
await replay_buffer.add.call_one(trajectory_0)
await replay_buffer.add.call_one(trajectory_1)
await replay_buffer.add.call_one(trajectory_2)
episode_0 = TestEpisode(policy_version=0)
episode_1 = TestEpisode(policy_version=1)
episode_2 = TestEpisode(policy_version=2)
await replay_buffer.add.call_one(episode_0)
await replay_buffer.add.call_one(episode_1)
await replay_buffer.add.call_one(episode_2)
assert replay_buffer._numel.call_one().get() == 3
samples = await replay_buffer.sample.call_one(
curr_policy_version=2,
Expand All @@ -112,8 +127,8 @@ async def test_sample_dp_size(self) -> None:

# Add enough trajectories to sample
for i in range(10):
trajectory = Trajectory(policy_version=0)
await replay_buffer.add.call_one(trajectory)
episode = TestEpisode(policy_version=0)
await replay_buffer.add.call_one(episode)

# Sample and verify len(samples) == dp_size
samples = await replay_buffer.sample.call_one(curr_policy_version=0)
Expand Down
Loading