@@ -304,10 +304,11 @@ def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
304304 def __call__ (self , ids : torch .Tensor ):
305305 assert ids .ndim == 2
306306 assert ids .shape [- 1 ] == len (self .axes_dims )
307+ device = ids .device
307308
308309 if self .freqs_cis is None :
309310 self .freqs_cis = self .precompute_freqs_cis (self .axes_dims , self .axes_lens , theta = self .theta )
310- self .freqs_cis = [freqs_cis .cuda ( ) for freqs_cis in self .freqs_cis ]
311+ self .freqs_cis = [freqs_cis .to ( device ) for freqs_cis in self .freqs_cis ]
311312
312313 result = []
313314 for i in range (len (self .axes_dims )):
@@ -596,6 +597,7 @@ def forward(
596597 x_freqs_cis [i ] = torch .cat ([freqs_item , freqs_pad_tensor .repeat (pad_len , 1 )])
597598 x_attn_mask [i , seq_len :] = 0
598599 x = torch .stack (x )
600+ x_freqs_cis = torch .stack (x_freqs_cis )
599601
600602 for layer in self .noise_refiner :
601603 x = layer (
@@ -638,6 +640,7 @@ def forward(
638640 cap_freqs_cis [i ] = torch .cat ([freqs_item , freqs_pad_tensor .repeat (pad_len , 1 )])
639641 cap_attn_mask [i , seq_len :] = 0
640642 cap_feats = torch .stack (cap_feats )
643+ cap_freqs_cis = torch .stack (cap_freqs_cis )
641644
642645 for layer in self .context_refiner :
643646 cap_feats = layer (
@@ -680,6 +683,7 @@ def forward(
680683 unified_freqs_cis [i ] = torch .cat ([freqs_item , freqs_pad_tensor .repeat (pad_len , 1 )])
681684 unified_attn_mask [i , seq_len :] = 0
682685 unified = torch .stack (unified )
686+ unified_freqs_cis = torch .stack (unified_freqs_cis )
683687
684688 for layer in self .layers :
685689 unified = layer (
0 commit comments