Skip to content

Commit e38347a

Browse files
authored
[BugFix] Fix shape mismatch in _set_index_in_td with trailing dims of 1 (#3517)
1 parent 83c2101 commit e38347a

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

test/test_libs.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
OneHot,
6565
ReplayBuffer,
6666
ReplayBufferEnsemble,
67+
TensorDictReplayBuffer,
6768
Unbounded,
6869
UnboundedDiscrete,
6970
)
@@ -5128,6 +5129,34 @@ def test_collector(self, task, parallel):
51285129
for _ in collector:
51295130
break
51305131

5132+
def test_single_agent_group_replay_buffer(self):
    """Regression test for gh#3515.

    A group containing exactly one agent produces tensors with a trailing
    dim of size 1; extending a ``TensorDictReplayBuffer`` with such data
    previously tripped a shape mismatch inside ``_set_index_in_td``.
    """
    env = PettingZooEnv(
        task="simple_v3",
        parallel=True,
        seed=0,
        use_mask=False,
    )
    # "simple_v3" exposes a single group with a single agent.
    group_name = next(iter(env.group_map))
    assert len(env.group_map[group_name]) == 1

    rollout = env.rollout(10)
    traj_len = rollout.shape[0]
    n_agents = 1

    # Reproduce the gh#3515 scenario: a replay-buffer Transform reshapes
    # collector output to (n_envs, traj_len, n_agents). With n_agents=1
    # the trailing singleton dim made _set_index_in_td match the wrong
    # number of batch dims.
    sample_td = rollout.unsqueeze(-1).unsqueeze(0)
    assert sample_td.shape == torch.Size([1, traj_len, n_agents])

    buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(10_000, ndim=3),
        batch_size=4,
    )
    # Must not raise: this extend() is the whole point of the regression test.
    buffer.extend(sample_td)
5159+
51315160

51325161
@pytest.mark.skipif(not _has_robohive, reason="RoboHive not found")
51335162
class TestRoboHive:

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1657,7 +1657,7 @@ def _set_index_in_td(self, tensordict, index):
16571657
if _is_int(index):
16581658
index = torch.as_tensor(index, device=tensordict.device)
16591659
elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
1660-
for dim in range(2, tensordict.ndim + 1):
1660+
for dim in range(tensordict.ndim, 1, -1):
16611661
if index.shape[:1].numel() == tensordict.shape[:dim].numel():
16621662
# if index has 2 dims and is in a non-zero format
16631663
index = index.unflatten(0, tensordict.shape[:dim])

torchrl/envs/transforms/transforms.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7920,8 +7920,12 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
79207920
for parent_key in self.truncated_keys:
79217921
parent_truncated = next_tensordict.get(parent_key, None)
79227922
if parent_truncated is not None:
7923-
# Expand parent truncated to match nested shape and apply OR
7924-
expanded = parent_truncated.expand_as(nested_truncated)
7923+
# Insert extra dims (e.g. agent dims) so the parent is
7924+
# broadcastable to the nested agent-level shape.
7925+
parent_val = parent_truncated
7926+
while parent_val.ndim < nested_truncated.ndim:
7927+
parent_val = parent_val.unsqueeze(-2)
7928+
expanded = parent_val.expand_as(nested_truncated)
79257929
next_tensordict.set(nested_key, nested_truncated | expanded)
79267930
break
79277931

@@ -7937,8 +7941,10 @@ def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
79377941
for parent_key in self.done_keys:
79387942
parent_done = next_tensordict.get(parent_key, None)
79397943
if parent_done is not None:
7940-
# Expand parent done to match nested shape and apply OR
7941-
expanded = parent_done.expand_as(nested_done)
7944+
parent_val = parent_done
7945+
while parent_val.ndim < nested_done.ndim:
7946+
parent_val = parent_val.unsqueeze(-2)
7947+
expanded = parent_val.expand_as(nested_done)
79427948
next_tensordict.set(nested_key, nested_done | expanded)
79437949
break
79447950

0 commit comments

Comments
 (0)