Skip to content

Commit df4dc75

Browse files
authored
Add files via upload
1 parent 148cffa commit df4dc75

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

sam3/sam/transformer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,8 @@ def __init__(
282282
self.compute_cis = partial(
283283
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
284284
)
285-
device = torch.device("cuda") if torch.cuda.is_available() else None
286285
self.freqs_cis = self.compute_cis(
287-
end_x=feat_sizes[0], end_y=feat_sizes[1], device=device
286+
end_x=feat_sizes[0], end_y=feat_sizes[1], device=None
288287
)
289288
if self.use_rope_real:
290289
self.freqs_cis_real = self.freqs_cis.real
@@ -306,10 +305,14 @@ def forward(
306305

307306
# Apply rotary position encoding
308307
w = h = math.sqrt(q.shape[-2])
309-
if self.freqs_cis.shape[0] != q.shape[-2]:
308+
if (
309+
self.freqs_cis.device != q.device
310+
or self.freqs_cis.shape[0] != q.shape[-2]
311+
):
310312
self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device)
311-
self.freqs_cis_real = self.freqs_cis.real
312-
self.freqs_cis_imag = self.freqs_cis.imag
313+
if self.use_rope_real:
314+
self.freqs_cis_real = self.freqs_cis.real
315+
self.freqs_cis_imag = self.freqs_cis.imag
313316
if q.shape[-2] != k.shape[-2]:
314317
assert self.rope_k_repeat
315318

0 commit comments

Comments
 (0)