Skip to content

Commit 28878c6

Browse files
committed
fixed unit tests
1 parent 1105456 commit 28878c6

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

apps/toy_rl/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from forge.actors.collector import Collector
1818

19-
from forge.data.replay_buffer import ReplayBuffer
19+
from forge.actors.replay_buffer import ReplayBuffer
2020
from forge.interfaces import Environment, Policy
2121
from forge.types import Action, Observation, State
2222
from monarch.actor import endpoint, proc_mesh
@@ -255,7 +255,7 @@ async def replay_buffer_sampler_task():
255255
)
256256

257257
print(
258-
f" Step {i+1:2d}: State={state_value:6.2f} → Action={action_value:6.2f}"
258+
f" Step {i + 1:2d}: State={state_value:6.2f} → Action={action_value:6.2f}"
259259
)
260260

261261
if idx < len(trajectories): # Add spacing between trajectories

src/forge/actors/collector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212

1313
from typing import Callable
1414

15-
from forge.data.replay_buffer import ReplayBuffer
15+
from monarch.actor import Actor, endpoint
16+
17+
from forge.actors.replay_buffer import ReplayBuffer
1618

1719
from forge.interfaces import Policy
1820

1921
from forge.types import Trajectory
2022

21-
from monarch.actor import Actor, endpoint
22-
2323

2424
class Collector(Actor):
2525
"""Collects trajectories for the training loop."""

tests/unit_tests/rl/test_toy_rl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from apps.toy_rl.main import ToyAction, ToyEnvironment, ToyObservation, ToyPolicy
2222
from forge.actors.collector import Collector
23-
from forge.data.replay_buffer import ReplayBuffer
23+
from forge.actors.replay_buffer import ReplayBuffer
2424

2525
# local_proc_mesh is an implementation of proc_mesh for
2626
# testing purposes. It lacks some features of the real proc_mesh
@@ -214,6 +214,7 @@ async def test_full_rl_pipeline_simulation(self):
214214
1, # batch_size
215215
1, # max_policy_age
216216
)
217+
await replay_buffer.setup.call()
217218
collector = await proc.spawn(
218219
"collector",
219220
Collector,

tests/unit_tests/test_replay_buffer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010
import pytest_asyncio
11-
from forge.data.replay_buffer import ReplayBuffer
11+
from forge.actors.replay_buffer import ReplayBuffer
1212
from forge.types import Trajectory
1313

1414
from monarch.actor import proc_mesh
@@ -21,6 +21,7 @@ async def replay_buffer(self) -> ReplayBuffer:
2121
replay_buffer = await mesh.spawn(
2222
"replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1
2323
)
24+
await replay_buffer.setup.call()
2425
return replay_buffer
2526

2627
@pytest.mark.asyncio

0 commit comments

Comments
 (0)