@@ -244,28 +244,30 @@ class CogView4RotaryPosEmbed(nn.Module):
244244    def  __init__ (self , dim : int , patch_size : int , rope_axes_dim : Tuple [int , int ], theta : float  =  10000.0 ) ->  None :
245245        super ().__init__ ()
246246
247+         self .dim  =  dim 
247248        self .patch_size  =  patch_size 
248249        self .rope_axes_dim  =  rope_axes_dim 
249- 
250-         dim_h , dim_w  =  dim  //  2 , dim  //  2 
251-         h_inv_freq  =  1.0  /  (theta  **  (torch .arange (0 , dim_h , 2 , dtype = torch .float32 )[: (dim_h  //  2 )].float () /  dim_h ))
252-         w_inv_freq  =  1.0  /  (theta  **  (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w  //  2 )].float () /  dim_w ))
253-         h_seq  =  torch .arange (self .rope_axes_dim [0 ])
254-         w_seq  =  torch .arange (self .rope_axes_dim [1 ])
255-         self .freqs_h  =  self .register_buffer ("freqs_h" , torch .outer (h_seq , h_inv_freq ), persistent = False )
256-         self .freqs_w  =  self .register_buffer ("freqs_h" , torch .outer (w_seq , w_inv_freq ), persistent = False )
250+         self .theta  =  theta 
257251
258252    def  forward (self , hidden_states : torch .Tensor ) ->  Tuple [torch .Tensor , torch .Tensor ]:
259253        batch_size , num_channels , height , width  =  hidden_states .shape 
260254        height , width  =  height  //  self .patch_size , width  //  self .patch_size 
261255
262-         h_idx  =  torch .arange (height , device = self .freqs_h .device )
263-         w_idx  =  torch .arange (width , device = self .freqs_w .device )
256+         dim_h , dim_w  =  self .dim  //  2 , self .dim  //  2 
257+         h_inv_freq  =  1.0  /  (self .theta  **  (torch .arange (0 , dim_h , 2 , dtype = torch .float32 )[: (dim_h  //  2 )].float () /  dim_h ))
258+         w_inv_freq  =  1.0  /  (self .theta  **  (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w  //  2 )].float () /  dim_w ))
259+         h_seq  =  torch .arange (self .rope_axes_dim [0 ])
260+         w_seq  =  torch .arange (self .rope_axes_dim [1 ])
261+         freqs_h  =  torch .outer (h_seq , h_inv_freq )
262+         freqs_w  =  torch .outer (w_seq , w_inv_freq )
263+ 
264+         h_idx  =  torch .arange (height , device = freqs_h .device )
265+         w_idx  =  torch .arange (width , device = freqs_w .device )
264266        inner_h_idx  =  h_idx  *  self .rope_axes_dim [0 ] //  height 
265267        inner_w_idx  =  w_idx  *  self .rope_axes_dim [1 ] //  width 
266268
267-         freqs_h  =  self . freqs_h [inner_h_idx ]
268-         freqs_w  =  self . freqs_w [inner_w_idx ]
269+         freqs_h  =  freqs_h [inner_h_idx ]
270+         freqs_w  =  freqs_w [inner_w_idx ]
269271
270272        # Create position matrices for height and width 
271273        # [height, 1, dim//4] and [1, width, dim//4] 
0 commit comments