@@ -241,20 +241,22 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
 
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
-        # Use float32 for MPS compatibility
-        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
             freqs_cis.append(emb)
         return freqs_cis
 
     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
         result = []
         for i in range(len(self.axes_dim)):
             freqs = self.freqs_cis[i].to(ids.device)
             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
+        return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         batch_size = len(hidden_states)
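The net effect of this change: the rotary frequency tables are always precomputed in float64 on CPU, and on Apple Silicon the position-id gather is routed through CPU before the result is moved back to the original device, since the MPS backend has no float64 kernels. Below is a minimal standalone sketch of that round-trip pattern; the function and argument names are illustrative, not from this commit, and unlike the code above it downcasts to float32 so the final move always succeeds on MPS:

```python
import torch

def gather_freqs_on_cpu(freqs: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    # freqs: [max_len, dim] rotary table, precomputed on CPU in float64.
    # ids:   [batch, seq] integer positions, possibly on an MPS device.
    device = ids.device
    if device.type == "mps":
        ids = ids.to("cpu")  # the float64 gather has to run on CPU
    # Build a [batch, seq, dim] index so gather picks one table row per position.
    index = ids.to(torch.int64).unsqueeze(-1).repeat(1, 1, freqs.shape[-1])
    out = torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
    # Downcast before leaving CPU: float64 tensors cannot be moved to MPS.
    return out.float().to(device)
```

Doing the trigonometry in float64 helps preserve the low-frequency RoPE components at large position indices, which float32 can round away; the CPU round-trip is what keeps that precision available on backends without float64 support.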