Skip to content

Commit 148cffa

Browse files
authored
Add files via upload
1 parent ce0d2c5 commit 148cffa

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

sam3/model/decoder.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -277,10 +277,8 @@ def __init__(
277277

278278
if resolution is not None and stride is not None:
279279
feat_size = resolution // stride
280-
coords_h, coords_w = self._get_coords(
281-
feat_size, feat_size, device="cuda"
282-
)
283-
self.compilable_cord_cache = (coords_h, coords_w)
280+
# Build cache lazily on the actual reference_boxes device to avoid
281+
# mixing CPU/GPU tensors when users run on CPU.
284282
self.compilable_stored_size = (feat_size, feat_size)
285283

286284
self.roi_pooler = (
@@ -332,7 +330,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
332330
H, W = feat_size
333331
boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
334332
bs, num_queries, _ = boxes_xyxy.shape
335-
if self.compilable_cord_cache is None:
333+
if (
334+
self.compilable_cord_cache is None
335+
or any(c.device != reference_boxes.device for c in self.compilable_cord_cache)
336+
or self.compilable_stored_size != (H, W)
337+
):
336338
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
337339
self.compilable_stored_size = (H, W)
338340

@@ -345,11 +347,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
345347
else:
346348
# cache miss, will create compilation issue
347349
# In case we're not compiling, we'll still rely on the dict-based cache
348-
if feat_size not in self.coord_cache:
349-
self.coord_cache[feat_size] = self._get_coords(
350+
cache_key = (feat_size, reference_boxes.device)
351+
if cache_key not in self.coord_cache:
352+
self.coord_cache[cache_key] = self._get_coords(
350353
H, W, reference_boxes.device
351354
)
352-
coords_h, coords_w = self.coord_cache[feat_size]
355+
coords_h, coords_w = self.coord_cache[cache_key]
353356

354357
assert coords_h.shape == (H,)
355358
assert coords_w.shape == (W,)

0 commit comments

Comments
 (0)