@@ -252,20 +252,18 @@ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], th
252252        w_inv_freq  =  1.0  /  (theta  **  (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w  //  2 )].float () /  dim_w ))
253253        h_seq  =  torch .arange (self .rope_axes_dim [0 ])
254254        w_seq  =  torch .arange (self .rope_axes_dim [1 ])
255-         self .freqs_h  =  torch .outer (h_seq , h_inv_freq )
256-         self .freqs_w  =  torch .outer (w_seq , w_inv_freq )
255+         self .freqs_h  =  torch .nn . Buffer ( torch . outer (h_seq , h_inv_freq ) )
256+         self .freqs_w  =  torch .nn . Buffer ( torch . outer (w_seq , w_inv_freq ) )
257257
258258    def  forward (self , hidden_states : torch .Tensor ) ->  Tuple [torch .Tensor , torch .Tensor ]:
259259        batch_size , num_channels , height , width  =  hidden_states .shape 
260260        height , width  =  height  //  self .patch_size , width  //  self .patch_size 
261261
262-         h_idx  =  torch .arange (height )
263-         w_idx  =  torch .arange (width )
262+         h_idx  =  torch .arange (height ,  device = self . freqs_h . device )
263+         w_idx  =  torch .arange (width ,  device = self . freqs_w . device )
264264        inner_h_idx  =  h_idx  *  self .rope_axes_dim [0 ] //  height 
265265        inner_w_idx  =  w_idx  *  self .rope_axes_dim [1 ] //  width 
266266
267-         self .freqs_h  =  self .freqs_h .to (hidden_states .device )
268-         self .freqs_w  =  self .freqs_w .to (hidden_states .device )
269267        freqs_h  =  self .freqs_h [inner_h_idx ]
270268        freqs_w  =  self .freqs_w [inner_w_idx ]
271269
0 commit comments