@@ -288,6 +288,91 @@ def forward(self, latent):
         return (latent + pos_embed).to(latent.dtype)


+class OmniGenPatchEmbed(nn.Module):
+    """2D Image to Patch Embedding with support for OmniGen."""
+
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 4,
+        embed_dim: int = 768,
+        bias: bool = True,
+        interpolation_scale: float = 1,
+        pos_embed_max_size: int = 192,
+        base_size: int = 64,
+    ):
+        super().__init__()
+
+        self.output_image_proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.input_image_proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+
+        self.patch_size = patch_size
+        self.interpolation_scale = interpolation_scale
+        self.pos_embed_max_size = pos_embed_max_size
+
+        pos_embed = get_2d_sincos_pos_embed(
+            embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale
+        )
+        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
+    def cropped_pos_embed(self, height, width):
+        """Crops the pre-computed positional embeddings to the current resolution (same cropping scheme as SD3)."""
+        if self.pos_embed_max_size is None:
+            raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        if height > self.pos_embed_max_size:
+            raise ValueError(
+                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+        if width > self.pos_embed_max_size:
+            raise ValueError(
+                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+        return spatial_pos_embed
+    def patch_embeddings(self, latent, is_input_image: bool):
+        if is_input_image:
+            latent = self.input_image_proj(latent)
+        else:
+            latent = self.output_image_proj(latent)
+        latent = latent.flatten(2).transpose(1, 2)
+        return latent
+
+    def forward(self, latent, is_input_image: bool, padding_latent=None):
+        if isinstance(latent, list):
+            if padding_latent is None:
+                padding_latent = [None] * len(latent)
+            patched_latents = []
+            for sub_latent, padding in zip(latent, padding_latent):
+                height, width = sub_latent.shape[-2:]
+                sub_latent = self.patch_embeddings(sub_latent, is_input_image)
+                pos_embed = self.cropped_pos_embed(height, width)
+                sub_latent = sub_latent + pos_embed
+                if padding is not None:
+                    sub_latent = torch.cat([sub_latent, padding], dim=-2)
+                patched_latents.append(sub_latent)
+            latent = patched_latents
+        else:
+            height, width = latent.shape[-2:]
+            pos_embed = self.cropped_pos_embed(height, width)
+            latent = self.patch_embeddings(latent, is_input_image)
+            latent = latent + pos_embed
+
+        return latent
+
+
 class LuminaPatchEmbed(nn.Module):
     """2D Image to Patch Embedding with support for Lumina-T2X"""

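For reviewers, a minimal usage sketch (not part of the diff) of the two paths in `OmniGenPatchEmbed.forward`: a single output latent is patchified by `output_image_proj` and receives a center-cropped slice of the pre-computed sin-cos positional table, while a list of conditioning latents is handled per element via `input_image_proj`. The import path, shapes, and parameter values below are illustrative assumptions.

    import torch

    # Hypothetical import path, assuming this PR's module layout in diffusers' embeddings module.
    from diffusers.models.embeddings import OmniGenPatchEmbed

    embedder = OmniGenPatchEmbed(patch_size=2, in_channels=4, embed_dim=768, pos_embed_max_size=192)

    # Single output latent: (B, C, H, W) -> (B, (H // 2) * (W // 2), embed_dim)
    latent = torch.randn(1, 4, 64, 64)
    tokens = embedder(latent, is_input_image=False)
    print(tokens.shape)  # torch.Size([1, 1024, 768])

    # List of input-image latents at mixed resolutions -> one token tensor per latent
    inputs = [torch.randn(1, 4, 64, 64), torch.randn(1, 4, 32, 32)]
    token_list = embedder(inputs, is_input_image=True)
    print([t.shape for t in token_list])  # [torch.Size([1, 1024, 768]), torch.Size([1, 256, 768])]

Note that the positional table is built once for `pos_embed_max_size` (192 x 192 patches by default) and center-cropped per call, so any latent resolution up to 384 x 384 works without re-interpolating the embedding.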
@@ -935,6 +1020,48 @@ def forward(self, timesteps):
         return t_emb


+class OmniGenTimestepEmbed(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations for OmniGen.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+
+        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t, dtype=torch.float32):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""

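And a corresponding sketch (again illustrative, not part of the diff) for `OmniGenTimestepEmbed`: `timestep_embedding` maps N scalar timesteps to an (N, 256) vector of cos(t * w_i) followed by sin(t * w_i), with w_i = max_period ** (-i / half) and half = 128, and the two-layer MLP then projects that to `hidden_size`. The hidden size and import path are assumptions.

    import torch

    # Hypothetical import path; hidden_size=3072 is only an illustrative value.
    from diffusers.models.embeddings import OmniGenTimestepEmbed

    time_embed = OmniGenTimestepEmbed(hidden_size=3072)

    timesteps = torch.tensor([0.0, 250.5, 999.0])  # timesteps may be fractional

    freqs = OmniGenTimestepEmbed.timestep_embedding(timesteps, dim=256)
    print(freqs.shape)  # torch.Size([3, 256]) -- cos half followed by sin half

    t_emb = time_embed(timesteps, dtype=torch.float32)
    print(t_emb.shape)  # torch.Size([3, 3072])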