Skip to content

Commit 3ebe93d

Browse files
authored
[BugFix] remove update from add when transform is not in-place to keep new metadata (#3050)
1 parent 5a5f63d commit 3ebe93d

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,15 +705,16 @@ def add(self, data: Any) -> int:
705705
make_none = False
706706
# Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
707707
is_tc = is_tensor_collection(data)
708-
with data.unsqueeze(-1) if is_tc else contextlib.nullcontext(
709-
data
710-
) as data_unsq:
708+
cm = data.unsqueeze(-1) if is_tc else contextlib.nullcontext(data)
709+
new_data = None
710+
with cm as data_unsq:
711711
data_unsq_r = self._transform.inv(data_unsq)
712712
if is_tc and data_unsq_r is not None:
713713
# this is a no-op whenever the result matches the input
714-
data_unsq.update(data_unsq_r)
714+
new_data = data_unsq_r.squeeze(-1)
715715
else:
716716
make_none = data_unsq_r is None
717+
data = new_data if new_data is not None else data
717718
if make_none:
718719
data = None
719720
if data is None:

torchrl/envs/transforms/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4507,9 +4507,9 @@ class DeviceCastTransform(Transform):
45074507
"""Moves data from one device to another.
45084508
45094509
Args:
4510-
device (torch.device or equivalent): the destination device.
4511-
orig_device (torch.device or equivalent): the origin device. If not specified and
4512-
a parent environment exists, it it retrieved from it. In all other cases,
4510+
device (torch.device or equivalent): the destination device (outside the environment or buffer).
4511+
orig_device (torch.device or equivalent): the origin device (inside the environment or buffer).
4512+
If not specified and a parent environment exists, it it retrieved from it. In all other cases,
45134513
it remains unspecified.
45144514
45154515
Keyword Args:

0 commit comments

Comments
 (0)