@@ -76,15 +76,15 @@ async def test_sample(self, replay_buffer) -> None:
7676 # Test a simple sampling w/ no evictions
7777 samples = await replay_buffer .sample .call_one (curr_policy_version = 1 )
7878 assert samples is not None
79- assert len (samples ) == 2
79+ assert len (samples [ 0 ] ) == 2
8080
8181 # Test sampling with overriding batch size
8282 await replay_buffer .add .call_one (trajectory_0 )
8383 samples = await replay_buffer .sample .call_one (
8484 curr_policy_version = 1 , batch_size = 1
8585 )
8686 assert samples is not None
87- assert len (samples ) == 1
87+ assert len (samples [ 0 ] ) == 1
8888
8989 # Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
9090 await replay_buffer .add .call_one (trajectory_0 )
@@ -105,6 +105,11 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
105105 curr_policy_version = 2 , batch_size = 1
106106 )
107107 assert samples is not None
108- assert len (samples ) == 1
109- assert samples [0 ] == trajectory_1
108+ assert len (samples [ 0 ] ) == 1
109+ assert samples [0 ][ 0 ] == trajectory_1
110110 replay_buffer .clear .call_one ().get ()
111+
112+
113+ if __name__ == "__main__" :
114+ # Run tests with pytest
115+ pytest .main ([__file__ , "-v" ])
0 commit comments