|
@@ -47,6 +47,7 @@
     Composite,
     LazyTensorStorage,
     NonTensor,
+    RandomSampler,
     ReplayBuffer,
     TensorDictReplayBuffer,
     TensorSpec,
@@ -13271,6 +13272,52 @@ def test_multistep_transform_changes(self):
         assert rb[:]["next", "steps"][-1] == data["steps"][-1]
         assert t._buffer["steps"][-1] == data["steps"][-1]
 
+    @pytest.mark.parametrize("add_or_extend", ["add", "extend"])
+    def test_multisteptransform_single_item(self, add_or_extend):
+        # Configuration
+        buffer_size = 1000
+        n_step = 3
+        gamma = 0.99
+        device = "cpu"
+
+        rb = ReplayBuffer(
+            storage=LazyTensorStorage(max_size=buffer_size, device=device, ndim=1),
+            sampler=RandomSampler(),
+            transform=MultiStepTransform(n_steps=n_step, gamma=gamma),
+        )
+        obs_dict = lambda i: {"observation": torch.full((4,), i)}  # 4-dim observation
+        next_obs_dict = lambda i: {"observation": torch.full((4,), i)}
+
+        for i in range(10):
+            # Create transition with batch_size=[] (no batch dimension)
+            transition = TensorDict(
+                {
+                    "obs": TensorDict(obs_dict(i), batch_size=[]),
+                    "action": torch.full((2,), i),  # 2-dim action
+                    "next": TensorDict(
+                        {
+                            "obs": TensorDict(next_obs_dict(i), batch_size=[]),
+                            "done": torch.tensor(False, dtype=torch.bool),
+                            "reward": torch.tensor(float(i), dtype=torch.float32),
+                        },
+                        batch_size=[],
+                    ),
+                },
+                batch_size=[],
+            )
+
+            if add_or_extend == "add":
+                rb.add(transition)
+            else:
+                rb.extend(transition.unsqueeze(0))
+        rbcontent = rb[:]
+        assert (rbcontent["steps_to_next_obs"] == 3).all()
+        assert rbcontent.shape == (7,)
+        assert (rbcontent["next", "original_reward"] == torch.arange(7)).all()
+        assert (
+            rbcontent["next", "reward"] > rbcontent["next", "original_reward"]
+        ).all()
+
 
 class TestBatchSizeTransform(TransformBase):
     class MyEnv(EnvBase):
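For reference, a minimal standalone sketch (not taken from the PR) of the behaviour the new test asserts. The import paths are assumptions — in particular, MultiStepTransform is assumed to be importable from torchrl.envs — and the expected values in the comments simply restate the assertions of test_multisteptransform_single_item above.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer
from torchrl.envs import MultiStepTransform  # assumed import location

# Buffer whose writes are post-processed into 3-step transitions.
rb = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1000, ndim=1),
    sampler=RandomSampler(),
    transform=MultiStepTransform(n_steps=3, gamma=0.99),
)

for i in range(10):
    # One unbatched transition per call (batch_size=[]), as in the "add" branch of the test.
    rb.add(
        TensorDict(
            {
                "obs": TensorDict({"observation": torch.full((4,), i)}, batch_size=[]),
                "action": torch.full((2,), i),
                "next": TensorDict(
                    {
                        "obs": TensorDict({"observation": torch.full((4,), i)}, batch_size=[]),
                        "done": torch.tensor(False),
                        "reward": torch.tensor(float(i)),
                    },
                    batch_size=[],
                ),
            },
            batch_size=[],
        )
    )

content = rb[:]
# The transform holds back the newest n_steps transitions until their n-step
# return can be formed, so 10 single writes leave 7 visible entries, each
# flagged as being 3 steps away from its bootstrap observation.
assert content.shape == (7,)
assert (content["steps_to_next_obs"] == 3).all()
# ("next", "original_reward") keeps the raw 1-step reward (0..6 here), while
# ("next", "reward") carries the discounted 3-step sum, which is strictly larger.
assert (content["next", "original_reward"] == torch.arange(7)).all()
assert (content["next", "reward"] > content["next", "original_reward"]).all()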
|
|