@@ -277,10 +277,8 @@ def __init__(
277277
278278 if resolution is not None and stride is not None :
279279 feat_size = resolution // stride
280- coords_h , coords_w = self ._get_coords (
281- feat_size , feat_size , device = "cuda"
282- )
283- self .compilable_cord_cache = (coords_h , coords_w )
280+ # Build cache lazily on the actual reference_boxes device to avoid
281+ # mixing CPU/GPU tensors when users run on CPU.
284282 self .compilable_stored_size = (feat_size , feat_size )
285283
286284 self .roi_pooler = (
@@ -332,7 +330,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
332330 H , W = feat_size
333331 boxes_xyxy = box_cxcywh_to_xyxy (reference_boxes ).transpose (0 , 1 )
334332 bs , num_queries , _ = boxes_xyxy .shape
335- if self .compilable_cord_cache is None :
333+ if (
334+ self .compilable_cord_cache is None
335+ or any (c .device != reference_boxes .device for c in self .compilable_cord_cache )
336+ or self .compilable_stored_size != (H , W )
337+ ):
336338 self .compilable_cord_cache = self ._get_coords (H , W , reference_boxes .device )
337339 self .compilable_stored_size = (H , W )
338340
@@ -345,11 +347,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
345347 else :
346348 # cache miss, will create compilation issue
347349 # In case we're not compiling, we'll still rely on the dict-based cache
348- if feat_size not in self .coord_cache :
349- self .coord_cache [feat_size ] = self ._get_coords (
350+ cache_key = (feat_size , reference_boxes .device )
351+ if cache_key not in self .coord_cache :
352+ self .coord_cache [cache_key ] = self ._get_coords (
350353 H , W , reference_boxes .device
351354 )
352- coords_h , coords_w = self .coord_cache [feat_size ]
355+ coords_h , coords_w = self .coord_cache [cache_key ]
353356
354357 assert coords_h .shape == (H ,)
355358 assert coords_w .shape == (W ,)
0 commit comments