Skip to content

Commit fca4e76

Browse files
authored
[BugFix] Fix CUDA graph capture for Bounded spec projection (#3453)
1 parent 190a43d commit fca4e76

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

torchrl/data/tensor_specs.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -653,6 +653,8 @@ def erase_memoize_cache(self) -> None:
653653
def __getstate__(self):
654654
state = dict(self.__dict__)
655655
state["_encode"] = {}
656+
# Clear device-specific bounds cache to avoid serializing CUDA tensors
657+
state.pop("_bounds_cache", None)
656658
return state
657659

658660
@classmethod
@@ -2564,16 +2566,24 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
25642566
r = r.to(self.device)
25652567
return r
25662568

2569+
def _get_space_bounds(
2570+
self, device: torch.device
2571+
) -> tuple[torch.Tensor, torch.Tensor]:
2572+
"""Get space bounds on the specified device, using cache to avoid .to() during CUDA graph capture."""
2573+
if self.device == device:
2574+
return self.space.low, self.space.high
2575+
cache = self.__dict__.get("_bounds_cache")
2576+
if cache is None:
2577+
cache = self.__dict__["_bounds_cache"] = {}
2578+
if device not in cache:
2579+
cache[device] = (self.space.low.to(device), self.space.high.to(device))
2580+
return cache[device]
2581+
25672582
def _project(self, val: torch.Tensor) -> torch.Tensor:
2568-
low = self.space.low
2569-
high = self.space.high
2570-
if self.device != val.device:
2571-
low = low.to(val.device)
2572-
high = high.to(val.device)
2583+
low, high = self._get_space_bounds(val.device)
25732584
low = low.expand_as(val)
25742585
high = high.expand_as(val)
2575-
val = torch.clamp(val, low, high)
2576-
return val
2586+
return torch.clamp(val, low, high)
25772587

25782588
def is_in(self, val: torch.Tensor) -> bool:
25792589
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
@@ -2590,15 +2600,14 @@ def is_in(self, val: torch.Tensor) -> bool:
25902600
if not dtype_match:
25912601
return False
25922602
try:
2593-
within_bounds = (val >= self.space.low.to(val.device)).all() and (
2594-
val <= self.space.high.to(val.device)
2595-
).all()
2603+
low, high = self._get_space_bounds(val.device)
2604+
within_bounds = (val >= low).all() and (val <= high).all()
25962605
return within_bounds
25972606
except NotImplementedError:
2607+
low, high = self._get_space_bounds(val.device)
25982608
within_bounds = all(
2599-
(_val >= space.low.to(val.device)).all()
2600-
and (_val <= space.high.to(val.device)).all()
2601-
for (_val, space) in zip(val, self.space.unbind(0))
2609+
(_val >= _low).all() and (_val <= _high).all()
2610+
for (_val, _low, _high) in zip(val, low.unbind(0), high.unbind(0))
26022611
)
26032612
return within_bounds
26042613
except RuntimeError as err:

0 commit comments

Comments (0)