Commit 36eee40

OmniGen model.py
1 parent: ad5ecd1

File tree: 6 files changed (+1338, -1 lines)

src/diffusers/models/embeddings.py

Lines changed: 127 additions & 0 deletions
@@ -288,6 +288,91 @@ def forward(self, latent):

The commit inserts `OmniGenPatchEmbed` immediately before `LuminaPatchEmbed`:

```python
class OmniGenPatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for OmniGen."""

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        # Separate (identically shaped) projections for conditioning ("input")
        # images and the generated ("output") image.
        self.output_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.input_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        # Center-crop a height x width window out of the max-size positional grid.
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def patch_embeddings(self, latent, is_input_image: bool):
        if is_input_image:
            latent = self.input_image_proj(latent)
        else:
            latent = self.output_image_proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        return latent

    def forward(self, latent, is_input_image: bool, padding_latent=None):
        if isinstance(latent, list):
            if padding_latent is None:
                padding_latent = [None] * len(latent)
            patched_latents = []
            for sub_latent, padding in zip(latent, padding_latent):
                height, width = sub_latent.shape[-2:]
                sub_latent = self.patch_embeddings(sub_latent, is_input_image)
                pos_embed = self.cropped_pos_embed(height, width)
                sub_latent = sub_latent + pos_embed
                if padding is not None:
                    sub_latent = torch.cat([sub_latent, padding], dim=-2)
                patched_latents.append(sub_latent)
            # Return one token sequence per image in the list.
            latent = patched_latents
        else:
            height, width = latent.shape[-2:]
            pos_embed = self.cropped_pos_embed(height, width)
            latent = self.patch_embeddings(latent, is_input_image)
            latent = latent + pos_embed

        return latent
```
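
Taken together, `patch_embeddings` tokenizes a latent with one of the two projections, and `cropped_pos_embed` center-crops the 192×192 sincos grid down to the latent's patch grid. Below is a minimal shape check; it is not part of the commit, and the tensor sizes simply assume the default arguments above:

```python
import torch

from diffusers.models.embeddings import OmniGenPatchEmbed  # module path per this commit

embed = OmniGenPatchEmbed()  # patch_size=2, in_channels=4, embed_dim=768

# Single target latent: (batch, channels, height, width) -> (batch, tokens, embed_dim)
latent = torch.randn(1, 4, 64, 64)
tokens = embed(latent, is_input_image=False)
print(tokens.shape)  # torch.Size([1, 1024, 768]); (64 // 2) ** 2 = 1024 patches

# Conditioning images at mixed resolutions come in as a list; each image keeps
# its own token count and positional crop.
refs = [torch.randn(1, 4, 64, 64), torch.randn(1, 4, 32, 32)]
ref_tokens = embed(refs, is_input_image=True)
print([t.shape for t in ref_tokens])  # [1, 1024, 768] and [1, 256, 768]
```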

@@ -935,6 +1020,48 @@ def forward(self, timesteps):

It also adds `OmniGenTimestepEmbed` immediately before `GaussianFourierProjection`:

```python
class OmniGenTimestepEmbed(nn.Module):
    """Embeds scalar timesteps into vector representations for OmniGen."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column so the output is exactly `dim` wide.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.float32):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
```
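
A corresponding shape check for the timestep embedder (again a sketch, not from the commit; `hidden_size=768` is an arbitrary choice):

```python
import torch

from diffusers.models.embeddings import OmniGenTimestepEmbed  # module path per this commit

temb = OmniGenTimestepEmbed(hidden_size=768)  # frequency_embedding_size defaults to 256

t = torch.tensor([0.0, 250.5, 999.0])  # fractional timesteps are allowed
emb = temb(t)
print(emb.shape)  # torch.Size([3, 768]): 256-dim sinusoid -> Linear/SiLU/Linear -> 768
```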

src/diffusers/models/normalization.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def forward(

The only change extends a comment in the `chunk_dim == 1` branch:

```diff
     if self.chunk_dim == 1:
         # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
-        # other if-branch. This branch is specific to CogVideoX for now.
+        # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
         shift, scale = temb.chunk(2, dim=1)
         shift = shift[:, None, :]
         scale = scale[:, None, :]
```
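
For context, both branches feed the same `norm(x) * (1 + scale) + shift` modulation; only the packing order of the two halves of `temb` differs. A standalone illustration with hypothetical tensors (not the module's own code):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch of 2, embedding dim of 6 (so temb carries 2 * 6 values).
temb = torch.randn(2, 12)

# chunk_dim == 1 branch (CogVideoX / OmniGen): temb is packed as [shift | scale] ...
shift, scale = temb.chunk(2, dim=1)  # each (2, 6)

# ... while the other branch expects [scale | shift]. Downstream, both apply:
x = torch.randn(2, 5, 6)
out = F.layer_norm(x, (6,)) * (1 + scale[:, None, :]) + shift[:, None, :]
print(out.shape)  # torch.Size([2, 5, 6])
```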
