Skip to content

Commit f55b4a4

Browse files
committed
fix replay buffer tests
1 parent 4647052 commit f55b4a4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/unit_tests/test_replay_buffer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ async def test_sample(self, replay_buffer) -> None:
7676
# Test a simple sampling w/ no evictions
7777
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
7878
assert samples is not None
79-
assert len(samples) == 2
79+
assert len(samples[0]) == 2
8080

8181
# Test sampling with overriding batch size
8282
await replay_buffer.add.call_one(trajectory_0)
8383
samples = await replay_buffer.sample.call_one(
8484
curr_policy_version=1, batch_size=1
8585
)
8686
assert samples is not None
87-
assert len(samples) == 1
87+
assert len(samples[0]) == 1
8888

8989
# Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
9090
await replay_buffer.add.call_one(trajectory_0)
@@ -105,6 +105,11 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
105105
curr_policy_version=2, batch_size=1
106106
)
107107
assert samples is not None
108-
assert len(samples) == 1
109-
assert samples[0] == trajectory_1
108+
assert len(samples[0]) == 1
109+
assert samples[0][0] == trajectory_1
110110
replay_buffer.clear.call_one().get()
111+
112+
113+
if __name__ == "__main__":
114+
# Run tests with pytest
115+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)