@@ -653,6 +653,8 @@ def erase_memoize_cache(self) -> None:
653653 def __getstate__ (self ):
654654 state = dict (self .__dict__ )
655655 state ["_encode" ] = {}
656+ # Clear device-specific bounds cache to avoid serializing CUDA tensors
657+ state .pop ("_bounds_cache" , None )
656658 return state
657659
658660 @classmethod
@@ -2564,16 +2566,24 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
25642566 r = r .to (self .device )
25652567 return r
25662568
2569+ def _get_space_bounds (
2570+ self , device : torch .device
2571+ ) -> tuple [torch .Tensor , torch .Tensor ]:
2572+ """Get space bounds on the specified device, using cache to avoid .to() during CUDA graph capture."""
2573+ if self .device == device :
2574+ return self .space .low , self .space .high
2575+ cache = self .__dict__ .get ("_bounds_cache" )
2576+ if cache is None :
2577+ cache = self .__dict__ ["_bounds_cache" ] = {}
2578+ if device not in cache :
2579+ cache [device ] = (self .space .low .to (device ), self .space .high .to (device ))
2580+ return cache [device ]
2581+
25672582 def _project (self , val : torch .Tensor ) -> torch .Tensor :
2568- low = self .space .low
2569- high = self .space .high
2570- if self .device != val .device :
2571- low = low .to (val .device )
2572- high = high .to (val .device )
2583+ low , high = self ._get_space_bounds (val .device )
25732584 low = low .expand_as (val )
25742585 high = high .expand_as (val )
2575- val = torch .clamp (val , low , high )
2576- return val
2586+ return torch .clamp (val , low , high )
25772587
25782588 def is_in (self , val : torch .Tensor ) -> bool :
25792589 val_shape = _remove_neg_shapes (tensordict .utils ._shape (val ))
@@ -2590,15 +2600,14 @@ def is_in(self, val: torch.Tensor) -> bool:
25902600 if not dtype_match :
25912601 return False
25922602 try :
2593- within_bounds = (val >= self .space .low .to (val .device )).all () and (
2594- val <= self .space .high .to (val .device )
2595- ).all ()
2603+ low , high = self ._get_space_bounds (val .device )
2604+ within_bounds = (val >= low ).all () and (val <= high ).all ()
25962605 return within_bounds
25972606 except NotImplementedError :
2607+ low , high = self ._get_space_bounds (val .device )
25982608 within_bounds = all (
2599- (_val >= space .low .to (val .device )).all ()
2600- and (_val <= space .high .to (val .device )).all ()
2601- for (_val , space ) in zip (val , self .space .unbind (0 ))
2609+ (_val >= _low ).all () and (_val <= _high ).all ()
2610+ for (_val , _low , _high ) in zip (val , low .unbind (0 ), high .unbind (0 ))
26022611 )
26032612 return within_bounds
26042613 except RuntimeError as err :
0 commit comments