@@ -282,9 +282,8 @@ def __init__(
282282 self .compute_cis = partial (
283283 compute_axial_cis , dim = self .internal_dim // self .num_heads , theta = rope_theta
284284 )
285- device = torch .device ("cuda" ) if torch .cuda .is_available () else None
286285 self .freqs_cis = self .compute_cis (
287- end_x = feat_sizes [0 ], end_y = feat_sizes [1 ], device = device
286+ end_x = feat_sizes [0 ], end_y = feat_sizes [1 ], device = None
288287 )
289288 if self .use_rope_real :
290289 self .freqs_cis_real = self .freqs_cis .real
@@ -306,10 +305,14 @@ def forward(
306305
307306 # Apply rotary position encoding
308307 w = h = math .sqrt (q .shape [- 2 ])
309- if self .freqs_cis .shape [0 ] != q .shape [- 2 ]:
308+ if (
309+ self .freqs_cis .device != q .device
310+ or self .freqs_cis .shape [0 ] != q .shape [- 2 ]
311+ ):
310312 self .freqs_cis = self .compute_cis (end_x = w , end_y = h , device = q .device )
311- self .freqs_cis_real = self .freqs_cis .real
312- self .freqs_cis_imag = self .freqs_cis .imag
313+ if self .use_rope_real :
314+ self .freqs_cis_real = self .freqs_cis .real
315+ self .freqs_cis_imag = self .freqs_cis .imag
313316 if q .shape [- 2 ] != k .shape [- 2 ]:
314317 assert self .rope_k_repeat
315318
0 commit comments