Commit bb321e7

revert changes to embeddings, normalization, transformer

1 parent e26604c, commit bb321e7

File tree: 3 files changed, +144 -85 lines

src/diffusers/models/embeddings.py
Lines changed: 52 additions & 0 deletions

@@ -1598,6 +1598,58 @@ def forward(
         return objs


+class AllegroCombinedTimestepSizeEmbeddings(nn.Module):
+    """
+    For Allegro. TODO(aryan)
+    """
+
+    def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.outdim = size_emb_dim
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.use_additional_conditions = use_additional_conditions
+        if use_additional_conditions:
+            self.use_additional_conditions = True
+            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
+        if size.ndim == 1:
+            size = size[:, None]
+
+        if size.shape[0] != batch_size:
+            size = size.repeat(batch_size // size.shape[0], 1)
+            if size.shape[0] != batch_size:
+                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
+
+        current_batch_size, dims = size.shape[0], size.shape[1]
+        size = size.reshape(-1)
+        size_freq = self.additional_condition_proj(size).to(size.dtype)
+
+        size_emb = embedder(size_freq)
+        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
+        return size_emb
+
+    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        if self.use_additional_conditions:
+            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
+            aspect_ratio = self.apply_condition(
+                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
+            )
+            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+        else:
+            conditioning = timesteps_emb
+
+        return conditioning
+
+
 class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
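
Not part of the diff: a minimal usage sketch of the AllegroCombinedTimestepSizeEmbeddings module added above, based only on the signatures visible in this commit and assuming a checkout of this branch. The embedding_dim here is illustrative rather than an Allegro default; with use_additional_conditions=False (how the transformer below uses it), resolution and aspect_ratio can simply be passed as None.

import torch

from diffusers.models.embeddings import AllegroCombinedTimestepSizeEmbeddings

# Illustrative dimensions, not taken from the Allegro config.
emb = AllegroCombinedTimestepSizeEmbeddings(
    embedding_dim=1152, size_emb_dim=1152 // 3, use_additional_conditions=False
)

timestep = torch.randint(0, 1000, (2,))  # one diffusion timestep per sample
conditioning = emb(
    timestep,
    resolution=None,       # ignored when use_additional_conditions=False
    aspect_ratio=None,     # ignored when use_additional_conditions=False
    batch_size=2,
    hidden_dtype=torch.float32,
)
print(conditioning.shape)  # torch.Size([2, 1152])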

src/diffusers/models/normalization.py
Lines changed: 40 additions & 2 deletions

@@ -22,7 +22,11 @@

 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
+from .embeddings import (
+    AllegroCombinedTimestepSizeEmbeddings,
+    CombinedTimestepLabelEmbeddings,
+    PixArtAlphaCombinedTimestepSizeEmbeddings,
+)


 class AdaLayerNorm(nn.Module):
@@ -263,7 +267,6 @@ def forward(
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
-        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep

@@ -387,6 +390,41 @@ def forward(
         return x


+class AllegroAdaLayerNormSingle(nn.Module):
+    r"""
+    Norm layer adaptive layer norm single (adaLN-single).
+
+    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+    """
+
+    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.emb = AllegroCombinedTimestepSizeEmbeddings(
+            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+        )
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
+        batch_size: int = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # No modulation happening here.
+        embedded_timestep = self.emb(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
+        )
+        return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
 class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
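
Not part of the diff: a rough sketch of how the new AllegroAdaLayerNormSingle is meant to be called, again assuming a checkout of this branch and an illustrative embedding dimension. The first return value is the 6 * embedding_dim modulation tensor that the transformer blocks chunk into shift/scale/gate pairs; the second is the raw embedded timestep, which the transformer reuses for its final output modulation.

import torch

from diffusers.models.normalization import AllegroAdaLayerNormSingle

dim = 1152  # illustrative embedding dimension
norm = AllegroAdaLayerNormSingle(dim, use_additional_conditions=False)

timestep = torch.randint(0, 1000, (4,))
modulation, embedded_timestep = norm(timestep, batch_size=4, hidden_dtype=torch.float32)

print(modulation.shape)         # torch.Size([4, 6912]) == (batch_size, 6 * dim)
print(embedded_timestep.shape)  # torch.Size([4, 1152]) == (batch_size, dim)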

src/diffusers/models/transformers/transformer_allegro.py
Lines changed: 52 additions & 83 deletions

@@ -27,7 +27,7 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AllegroAdaLayerNormSingle


 logger = logging.get_logger(__name__)
@@ -36,30 +36,7 @@
 @maybe_allow_in_graph
 class AllegroTransformerBlock(nn.Module):
     r"""
-    Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
-
-    Args:
-        dim (`int`):
-            The number of channels in the input and output.
-        num_attention_heads (`int`):
-            The number of heads to use for multi-head attention.
-        attention_head_dim (`int`):
-            The number of channels in each head.
-        dropout (`float`, defaults to `0.0`):
-            The dropout probability to use.
-        cross_attention_dim (`int`, defaults to `2304`):
-            The dimension of the cross attention features.
-        activation_fn (`str`, defaults to `"gelu-approximate"`):
-            Activation function to be used in feed-forward.
-        attention_bias (`bool`, defaults to `False`):
-            Whether or not to use bias in attention projection layers.
-        only_cross_attention (`bool`, defaults to `False`):
-        norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether to use learnable elementwise affine parameters for normalization.
-        norm_eps (`float`, defaults to `1e-5`):
-            Epsilon value for normalization layers.
-        final_dropout (`bool` defaults to `False`):
-            Whether to apply a final dropout after the last feed-forward layer.
+    TODO(aryan): docs
     """

     def __init__(
@@ -71,8 +48,11 @@ def __init__(
         cross_attention_dim: Optional[int] = None,
         activation_fn: str = "geglu",
         attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
         norm_eps: float = 1e-5,
+        final_dropout: bool = False,
     ):
         super().__init__()

@@ -85,7 +65,8 @@ def __init__(
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
-            cross_attention_dim=cross_attention_dim,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
             processor=AllegroAttnProcessor2_0(),
         )

@@ -98,6 +79,7 @@ def __init__(
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
+            upcast_attention=upcast_attention,
             processor=AllegroAttnProcessor2_0(),
         )  # is self-attn if encoder_hidden_states is none

@@ -108,6 +90,7 @@ def __init__(
             dim,
             dropout=dropout,
             activation_fn=activation_fn,
+            final_dropout=final_dropout,
         )

         # 4. Scale-shift
@@ -164,63 +147,49 @@ def forward(
         ff_output = gate_mlp * ff_output

         hidden_states = ff_output + hidden_states
+
+        # TODO(aryan): maybe following line is not required
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
         return hidden_states


 class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
     _supports_gradient_checkpointing = True

-    r"""
-    A 3D Transformer model for video-like data.
-
-    Args:
-        patch_size (`int`, defaults to `2`):
-            The size of spatial patches to use in the patch embedding layer.
-        patch_size_t (`int`, defaults to `1`):
-            The size of temporal patches to use in the patch embedding layer.
-        num_attention_heads (`int`, defaults to `24`):
-            The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, defaults to `96`):
-            The number of channels in each head.
-        in_channels (`int`, defaults to `4`):
-            The number of channels in the input.
-        out_channels (`int`, *optional*, defaults to `4`):
-            The number of channels in the output.
-        num_layers (`int`, defaults to `32`):
-            The number of layers of Transformer blocks to use.
-        dropout (`float`, defaults to `0.0`):
-            The dropout probability to use.
-        cross_attention_dim (`int`, defaults to `2304`):
-            The dimension of the cross attention features.
-        attention_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention projection layers.
-        sample_height (`int`, defaults to `90`):
-            The height of the input latents.
-        sample_width (`int`, defaults to `160`):
-            The width of the input latents.
-        sample_frames (`int`, defaults to `22`):
-            The number of frames in the input latents.
-        activation_fn (`str`, defaults to `"gelu-approximate"`):
-            Activation function to use in feed-forward.
-        norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether or not to use elementwise affine in normalization layers.
-        norm_eps (`float`, defaults to `1e-5`):
-            The epsilon value to use in normalization layers.
-        caption_channels (`int`, defaults to `4096`):
-            Number of channels to use for projecting the caption embeddings.
-        interpolation_scale_h (`float`, defaults to `2.0`):
-            Scaling factor to apply in 3D positional embeddings across height dimension.
-        interpolation_scale_w (`float`, defaults to `2.0`):
-            Scaling factor to apply in 3D positional embeddings across width dimension.
-        interpolation_scale_t (`float`, defaults to `2.2`):
-            Scaling factor to apply in 3D positional embeddings across time dimension.
+    """
+    A 2D Transformer model for image-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        num_vector_embeds (`int`, *optional*):
+            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*):
+            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+            added to the hidden states.
+
+            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """

     @register_to_config
     def __init__(
         self,
         patch_size: int = 2,
-        patch_size_t: int = 1,
+        patch_size_temporal: int = 1,
         num_attention_heads: int = 24,
         attention_head_dim: int = 96,
         in_channels: int = 4,
@@ -233,6 +202,7 @@ def __init__(
         sample_width: int = 160,
         sample_frames: int = 22,
         activation_fn: str = "gelu-approximate",
+        upcast_attention: bool = False,
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         caption_channels: int = 4096,
@@ -275,6 +245,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     activation_fn=activation_fn,
                     attention_bias=attention_bias,
+                    upcast_attention=upcast_attention,
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                 )
@@ -288,7 +259,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

         # 4. Timestep embeddings
-        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+        self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

         # 5. Caption projection
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
@@ -309,13 +280,9 @@ def forward(
         return_dict: bool = True,
     ):
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p_t = self.config.patch_size_t
+        p_t = self.config.patch_size_temporal
         p = self.config.patch_size

-        post_patch_num_frames = num_frames // self.config.patch_size_t
-        post_patch_height = height // self.config.patch_size
-        post_patch_width = width // self.config.patch_size
-
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -350,20 +317,22 @@ def forward(
             encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

-        # 1. Timestep embeddings
+        # 1. Input
+        post_patch_num_frames = num_frames // self.config.patch_size_temporal
+        post_patch_height = height // self.config.patch_size
+        post_patch_width = width // self.config.patch_size
+
         timestep, embedded_timestep = self.adaln_single(
             timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
         )

-        # 2. Patch embeddings
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.pos_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

-        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
             if self.gradient_checkpointing:
@@ -395,16 +364,16 @@ def custom_forward(*inputs):
                     image_rotary_emb=image_rotary_emb,
                 )

-        # 4. Output normalization & projection
+        # 3. Output
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)

-        # modulation
+        # Modulation
         hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.squeeze(1)

-        # 5. Unpatchify
+        # unpatchify
         hidden_states = hidden_states.reshape(
             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
         )
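
Not part of the diff: the relocated "# 1. Input" arithmetic and the final unpatchify reshape are plain patch bookkeeping, sketched standalone below. The latent and patch sizes are made up for illustration, and the permute plus final reshape after the visible reshape are assumptions about how the inverse of patchification is typically completed, not lines taken from this commit.

import torch

# Hypothetical latent shape and patch sizes (not Allegro defaults).
batch_size, out_channels = 1, 4
num_frames, height, width = 8, 12, 20
p_t, p = 1, 2  # patch_size_temporal, patch_size

post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p

# What proj_out would hand back: one flattened token per (p_t, p, p) patch.
num_tokens = post_patch_num_frames * post_patch_height * post_patch_width
hidden_states = torch.randn(batch_size, num_tokens, p_t * p * p * out_channels)

# The reshape mirrors the forward pass above; the permute and final reshape are
# the assumed remaining steps that restore (batch, channels, frames, height, width).
hidden_states = hidden_states.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
print(output.shape)  # torch.Size([1, 4, 8, 12, 20])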

0 commit comments
