@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module):
# Diff hunk for the constructor. Lines tagged "+" are added, lines tagged "-" are
# removed, and lines carrying two fused line numbers (old, then new) are unchanged
# context. After this change the constructor only records its arguments as attributes.
244244 def __init__ (self , dim : int , patch_size : int , rope_axes_dim : Tuple [int , int ], theta : float = 10000.0 ) -> None :
245245 super ().__init__ ()
246246
# Added: keep `dim` on the instance; the added lines in forward() read self.dim.
247+ self .dim = dim
247248 self .patch_size = patch_size
248249 self .rope_axes_dim = rope_axes_dim
# Removed: the constructor no longer builds the per-axis frequency tables
# (self.freqs_h / self.freqs_w) from `dim`, `theta` and `rope_axes_dim`.
# The added lines in forward() of this same hunk rebuild them as local tensors.
249-
250- dim_h , dim_w = dim // 2 , dim // 2
251- h_inv_freq = 1.0 / (theta ** (torch .arange (0 , dim_h , 2 , dtype = torch .float32 )[: (dim_h // 2 )].float () / dim_h ))
252- w_inv_freq = 1.0 / (theta ** (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w // 2 )].float () / dim_w ))
253- h_seq = torch .arange (self .rope_axes_dim [0 ])
254- w_seq = torch .arange (self .rope_axes_dim [1 ])
255- self .freqs_h = torch .outer (h_seq , h_inv_freq )
256- self .freqs_w = torch .outer (w_seq , w_inv_freq )
# Added: keep `theta` on the instance; the added lines in forward() read self.theta.
250+ self .theta = theta
257251
258252 def forward (self , hidden_states : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
259253 batch_size , num_channels , height , width = hidden_states .shape
260254 height , width = height // self .patch_size , width // self .patch_size
261255
262- h_idx = torch .arange (height )
263- w_idx = torch .arange (width )
256+ dim_h , dim_w = self .dim // 2 , self .dim // 2
257+ h_inv_freq = 1.0 / (
258+ self .theta ** (torch .arange (0 , dim_h , 2 , dtype = torch .float32 )[: (dim_h // 2 )].float () / dim_h )
259+ )
260+ w_inv_freq = 1.0 / (
261+ self .theta ** (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w // 2 )].float () / dim_w )
262+ )
263+ h_seq = torch .arange (self .rope_axes_dim [0 ])
264+ w_seq = torch .arange (self .rope_axes_dim [1 ])
265+ freqs_h = torch .outer (h_seq , h_inv_freq )
266+ freqs_w = torch .outer (w_seq , w_inv_freq )
267+
268+ h_idx = torch .arange (height , device = freqs_h .device )
269+ w_idx = torch .arange (width , device = freqs_w .device )
264270 inner_h_idx = h_idx * self .rope_axes_dim [0 ] // height
265271 inner_w_idx = w_idx * self .rope_axes_dim [1 ] // width
266272
267- self .freqs_h = self .freqs_h .to (hidden_states .device )
268- self .freqs_w = self .freqs_w .to (hidden_states .device )
269- freqs_h = self .freqs_h [inner_h_idx ]
270- freqs_w = self .freqs_w [inner_w_idx ]
273+ freqs_h = freqs_h [inner_h_idx ]
274+ freqs_w = freqs_w [inner_w_idx ]
271275
272276 # Create position matrices for height and width
273277 # [height, 1, dim//4] and [1, width, dim//4]
0 commit comments