Commit 4914302

[BugFix] RB.add unsqueezes tds when applying the transform (#3047)
1 parent 1fc1e16 commit 4914302
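
Context, as a minimal sketch (not part of the commit; import paths assumed from the TorchRL/tensordict layout used in the new test below): ReplayBuffer.add() used to pass a batch-less TensorDict directly to the transform's inv() call, which generally expects a trailing time dimension; MultiStepTransform is one such transform. With this fix, add() unsqueezes the data temporarily, so the two write paths at the end of the sketch should be interchangeable:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer
    from torchrl.envs import MultiStepTransform  # import path assumed

    rb = ReplayBuffer(
        storage=LazyTensorStorage(max_size=1000, ndim=1),
        sampler=RandomSampler(),
        transform=MultiStepTransform(n_steps=3, gamma=0.99),
    )
    # A single transition with batch_size=[] (no batch/time dimension)
    transition = TensorDict(
        {
            "obs": TensorDict({"observation": torch.zeros(4)}, batch_size=[]),
            "action": torch.zeros(2),
            "next": TensorDict(
                {
                    "obs": TensorDict({"observation": torch.zeros(4)}, batch_size=[]),
                    "done": torch.tensor(False),
                    "reward": torch.tensor(0.0),
                },
                batch_size=[],
            ),
        },
        batch_size=[],
    )
    rb.add(transition)                  # now unsqueezes internally before inv()
    rb.extend(transition.unsqueeze(0))  # explicit time dimension, same effect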

File tree

2 files changed: +59, -1 lines


test/test_transforms.py

Lines changed: 47 additions & 0 deletions
@@ -47,6 +47,7 @@
     Composite,
     LazyTensorStorage,
     NonTensor,
+    RandomSampler,
     ReplayBuffer,
     TensorDictReplayBuffer,
     TensorSpec,
@@ -13271,6 +13272,52 @@ def test_multistep_transform_changes(self):
         assert rb[:]["next", "steps"][-1] == data["steps"][-1]
         assert t._buffer["steps"][-1] == data["steps"][-1]
 
+    @pytest.mark.parametrize("add_or_extend", ["add", "extend"])
+    def test_multisteptransform_single_item(self, add_or_extend):
+        # Configuration
+        buffer_size = 1000
+        n_step = 3
+        gamma = 0.99
+        device = "cpu"
+
+        rb = ReplayBuffer(
+            storage=LazyTensorStorage(max_size=buffer_size, device=device, ndim=1),
+            sampler=RandomSampler(),
+            transform=MultiStepTransform(n_steps=n_step, gamma=gamma),
+        )
+        obs_dict = lambda i: {"observation": torch.full((4,), i)}  # 4-dim observation
+        next_obs_dict = lambda i: {"observation": torch.full((4,), i)}
+
+        for i in range(10):
+            # Create transition with batch_size=[] (no batch dimension)
+            transition = TensorDict(
+                {
+                    "obs": TensorDict(obs_dict(i), batch_size=[]),
+                    "action": torch.full((2,), i),  # 2-dim action
+                    "next": TensorDict(
+                        {
+                            "obs": TensorDict(next_obs_dict(i), batch_size=[]),
+                            "done": torch.tensor(False, dtype=torch.bool),
+                            "reward": torch.tensor(float(i), dtype=torch.float32),
+                        },
+                        batch_size=[],
+                    ),
+                },
+                batch_size=[],
+            )
+
+            if add_or_extend == "add":
+                rb.add(transition)
+            else:
+                rb.extend(transition.unsqueeze(0))
+        rbcontent = rb[:]
+        assert (rbcontent["steps_to_next_obs"] == 3).all()
+        assert rbcontent.shape == (7,)
+        assert (rbcontent["next", "original_reward"] == torch.arange(7)).all()
+        assert (
+            rbcontent["next", "reward"] > rbcontent["next", "original_reward"]
+        ).all()
+
 
 class TestBatchSizeTransform(TransformBase):
     class MyEnv(EnvBase):
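
Why the buffer holds 7 elements after 10 writes: with n_steps=3, the MultiStepTransform appears to hold back the most recent 3 transitions in its internal buffer until their future steps are available, so 10 - 3 = 7 elements are visible in storage. The same counts hold on both the add and extend branches, which is exactly the equivalence the fix below guarantees.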

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 12 additions & 1 deletion
@@ -702,7 +702,18 @@ def add(self, data: Any) -> int:
         """
         if self._transform is not None and len(self._transform):
             with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
-                data = self._transform.inv(data)
+                make_none = False
+                # Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
+                is_tc = is_tensor_collection(data)
+                with data.unsqueeze(-1) if is_tc else contextlib.nullcontext(data) as data_unsq:
+                    data_unsq_r = self._transform.inv(data_unsq)
+                    if is_tc and data_unsq_r is not None:
+                        # this is a no-op whenever the result matches the input
+                        data_unsq.update(data_unsq_r)
+                    else:
+                        make_none = data_unsq_r is None
+                if make_none:
+                    data = None
         if data is None:
             return torch.zeros((0, self._storage.ndim), dtype=torch.long)
         return self._add(data)
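
The pattern above leans on tensordict's context-manager semantics for shape operations: inside the with block the data carries the extra trailing dimension the transform expects, and updates written to the unsqueezed view are propagated back to the original tensordict on exit, while the nullcontext branch leaves non-tensordict inputs untouched. A minimal sketch of that mechanism (assuming a recent tensordict release; the doubling stands in for an arbitrary inv() result):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"reward": torch.tensor(1.0)}, batch_size=[])
    with td.unsqueeze(-1) as td_unsq:  # batch_size [] -> [1]
        td_unsq["reward"] = td_unsq["reward"] * 2  # stand-in for transform.inv
    assert td["reward"] == 2.0  # the in-place change survives the exit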
