@@ -2611,78 +2611,3 @@ def forward(self, image_embeds: List[torch.Tensor]):
26112611 projected_image_embeds .append (image_embed )
26122612
26132613 return projected_image_embeds
2614-
2615-
2616- class CogViewRotary2DEmbedding (nn .Module ):
2617- def __init__ (
2618- self ,
2619- kv_channels : int ,
2620- rotary_percent : float ,
2621- max_h : int = 128 ,
2622- max_w : int = 128 ,
2623- rotary_interleaved : bool = False ,
2624- seq_len_interpolation_factor : float = None ,
2625- inner_interp : bool = False ,
2626- rotary_base : int = 10000 ,
2627- ) -> None :
2628- super ().__init__ ()
2629-
2630- dim = kv_channels
2631- if rotary_percent < 1.0 :
2632- dim = int (dim * rotary_percent )
2633- self .rotary_interleaved = rotary_interleaved
2634-
2635- self .seq_len_interpolation_factor = seq_len_interpolation_factor
2636- self .inner_interp = inner_interp
2637-
2638- dim_h = kv_channels // 2
2639- dim_w = kv_channels // 2
2640-
2641- device = torch .cuda .current_device ()
2642- h_inv_freq = 1.0 / (
2643- rotary_base
2644- ** (torch .arange (0 , dim_h , 2 , dtype = torch .float32 , device = device )[: (dim_h // 2 )].float () / dim_h )
2645- )
2646- w_inv_freq = 1.0 / (
2647- rotary_base
2648- ** (torch .arange (0 , dim_w , 2 , dtype = torch .float32 , device = device )[: (dim_w // 2 )].float () / dim_w )
2649- )
2650-
2651- h_seq = torch .arange (max_h , device = device , dtype = h_inv_freq .dtype )
2652- w_seq = torch .arange (max_w , device = device , dtype = w_inv_freq .dtype )
2653-
2654- self .freqs_h = torch .outer (h_seq , h_inv_freq )
2655- self .freqs_w = torch .outer (w_seq , w_inv_freq )
2656- self .max_h = max_h
2657- self .max_w = max_w
2658-
2659- def forward (
2660- self ,
2661- h_idx : torch .Tensor ,
2662- w_idx : torch .Tensor ,
2663- target_h : torch .Tensor = None ,
2664- target_w : torch .Tensor = None ,
2665- mask : torch .Tensor = None ,
2666- ) -> torch .Tensor :
2667- if self .inner_interp :
2668- inner_h_idx = (h_idx * self .max_h ) // target_h
2669- inner_w_idx = (w_idx * self .max_w ) // target_w
2670-
2671- h_emb = self .freqs_h [inner_h_idx ]
2672- w_emb = self .freqs_w [inner_w_idx ]
2673-
2674- else :
2675- h_emb = self .freqs_h [h_idx ]
2676- w_emb = self .freqs_w [w_idx ]
2677-
2678- mask = (mask == 1 ).unsqueeze (- 1 )
2679-
2680- emb = torch .cat ([h_emb , w_emb ], dim = - 1 ) * mask
2681-
2682- assert emb .ndim == 2 , f"expected emb to have 2 dimensions, got { emb .ndim } "
2683- if not self .rotary_interleaved :
2684- emb = torch .repeat_interleave (emb , 2 , dim = 0 )
2685- else :
2686- emb = torch .repeat_interleave (emb , 2 , dim = 1 )
2687-
2688- return emb
0 commit comments