File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -705,15 +705,16 @@ def add(self, data: Any) -> int:
705
705
make_none = False
706
706
# Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
707
707
is_tc = is_tensor_collection (data )
708
- with data .unsqueeze (- 1 ) if is_tc else contextlib .nullcontext (
709
- data
710
- ) as data_unsq :
708
+ cm = data .unsqueeze (- 1 ) if is_tc else contextlib .nullcontext (data )
709
+ new_data = None
710
+ with cm as data_unsq :
711
711
data_unsq_r = self ._transform .inv (data_unsq )
712
712
if is_tc and data_unsq_r is not None :
713
713
# this is a no-op whenever the result matches the input
714
- data_unsq . update ( data_unsq_r )
714
+ new_data = data_unsq_r . squeeze ( - 1 )
715
715
else :
716
716
make_none = data_unsq_r is None
717
+ data = new_data if new_data is not None else data
717
718
if make_none :
718
719
data = None
719
720
if data is None :
Original file line number Diff line number Diff line change @@ -4507,9 +4507,9 @@ class DeviceCastTransform(Transform):
4507
4507
"""Moves data from one device to another.
4508
4508
4509
4509
Args:
4510
- device (torch.device or equivalent): the destination device.
4511
- orig_device (torch.device or equivalent): the origin device. If not specified and
4512
- a parent environment exists, it it retrieved from it. In all other cases,
4510
+ device (torch.device or equivalent): the destination device (outside the environment or buffer) .
4511
+ orig_device (torch.device or equivalent): the origin device (inside the environment or buffer).
4512
+ If not specified and a parent environment exists, it it retrieved from it. In all other cases,
4513
4513
it remains unspecified.
4514
4514
4515
4515
Keyword Args:
You can’t perform that action at this time.
0 commit comments