
Commit 174621f

refactor part 8
1 parent bb321e7 commit 174621f

File tree: 3 files changed, +82 −135 lines

src/diffusers/models/embeddings.py
src/diffusers/models/normalization.py
src/diffusers/models/transformers/transformer_allegro.py

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 52 deletions
@@ -1598,58 +1598,6 @@ def forward(
         return objs


-class AllegroCombinedTimestepSizeEmbeddings(nn.Module):
-    """
-    For Allegro. TODO(aryan)
-    """
-
-    def __init__(self, embedding_dim: int, size_emb_dim: int, use_additional_conditions: bool = False):
-        super().__init__()
-
-        self.outdim = size_emb_dim
-        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-
-        self.use_additional_conditions = use_additional_conditions
-        if use_additional_conditions:
-            self.use_additional_conditions = True
-            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
-            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
-            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
-
-    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
-        if size.ndim == 1:
-            size = size[:, None]
-
-        if size.shape[0] != batch_size:
-            size = size.repeat(batch_size // size.shape[0], 1)
-            if size.shape[0] != batch_size:
-                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
-        current_batch_size, dims = size.shape[0], size.shape[1]
-        size = size.reshape(-1)
-        size_freq = self.additional_condition_proj(size).to(size.dtype)
-
-        size_emb = embedder(size_freq)
-        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
-        return size_emb
-
-    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
-        timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
-
-        if self.use_additional_conditions:
-            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
-            aspect_ratio = self.apply_condition(
-                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
-            )
-            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
-        else:
-            conditioning = timesteps_emb
-
-        return conditioning
-
-
 class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
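
Note on the removal above: with use_additional_conditions=False, the deleted AllegroCombinedTimestepSizeEmbeddings returned only the sinusoidal-plus-MLP timestep embedding, which is the same code path that PixArtAlphaCombinedTimestepSizeEmbeddings (kept just below it) already provides. A minimal sketch of the surviving path, assuming that class behaves as shown in this file; embedding_dim=2304 mirrors the Allegro inner dimension configured later in this commit, and the batch size is arbitrary:

    import torch
    from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

    # With use_additional_conditions=False the resolution/aspect-ratio branch is skipped,
    # matching what the removed Allegro class computed in this configuration.
    emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
        embedding_dim=2304, size_emb_dim=2304 // 3, use_additional_conditions=False
    )
    timestep = torch.randint(0, 1000, (2,))
    conditioning = emb(
        timestep, resolution=None, aspect_ratio=None, batch_size=2, hidden_dtype=torch.float32
    )
    print(conditioning.shape)  # expected: torch.Size([2, 2304])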

src/diffusers/models/normalization.py

Lines changed: 2 additions & 37 deletions
@@ -23,9 +23,8 @@
 from ..utils import is_torch_version
 from .activations import get_activation
 from .embeddings import (
-    AllegroCombinedTimestepSizeEmbeddings,
     CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings,
+    PixArtAlphaCombinedTimestepSizeEmbeddings
 )


@@ -267,6 +266,7 @@ def forward(
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep
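
The added default above is what lets callers omit added_cond_kwargs entirely when use_additional_conditions is False, as the Allegro transformer does further down in this commit. A minimal sketch, assuming AdaLayerNormSingle's constructor and return shapes are as in this file; embedding_dim=2304 mirrors the Allegro inner dimension and the batch size is arbitrary:

    import torch
    from diffusers.models.normalization import AdaLayerNormSingle

    norm = AdaLayerNormSingle(embedding_dim=2304, use_additional_conditions=False)
    timestep = torch.randint(0, 1000, (2,))

    # No added_cond_kwargs passed: the new fallback supplies
    # {"resolution": None, "aspect_ratio": None} before expanding it into self.emb(...).
    modulation, embedded_timestep = norm(timestep, batch_size=2, hidden_dtype=torch.float32)
    print(modulation.shape, embedded_timestep.shape)  # expected: (2, 6 * 2304) and (2, 2304)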

@@ -390,41 +390,6 @@ def forward(
         return x


-class AllegroAdaLayerNormSingle(nn.Module):
-    r"""
-    Norm layer adaptive layer norm single (adaLN-single).
-
-    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
-
-    Parameters:
-        embedding_dim (`int`): The size of each embedding vector.
-        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
-    """
-
-    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
-        super().__init__()
-
-        self.emb = AllegroCombinedTimestepSizeEmbeddings(
-            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
-        )
-
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
-
-    def forward(
-        self,
-        timestep: torch.Tensor,
-        added_cond_kwargs: Dict[str, torch.Tensor] = None,
-        batch_size: int = None,
-        hidden_dtype: Optional[torch.dtype] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # No modulation happening here.
-        embedded_timestep = self.emb(
-            timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
-        )
-        return self.linear(self.silu(embedded_timestep)), embedded_timestep
-
-
 class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 80 additions & 46 deletions
@@ -27,7 +27,7 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AllegroAdaLayerNormSingle
+from ..normalization import AdaLayerNormSingle


 logger = logging.get_logger(__name__)
@@ -36,7 +36,29 @@
 @maybe_allow_in_graph
 class AllegroTransformerBlock(nn.Module):
     r"""
-    TODO(aryan): docs
+    Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to be used in feed-forward.
+        attention_bias (`bool`, defaults to `False`):
+            Whether or not to use bias in attention projection layers.
+        only_cross_attention (`bool`, defaults to `False`):
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, defaults to `1e-5`):
+            Epsilon value for normalization layers.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
     """

     def __init__(
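
For orientation, a hypothetical instantiation of the block at the Allegro sizes documented above (dim = 24 heads x 96 channels = 2304, cross_attention_dim = 2304); this is a sketch only, and the keyword names for the leading arguments are taken from the new docstring rather than verified against this exact commit:

    from diffusers.models.transformers.transformer_allegro import AllegroTransformerBlock

    # Hypothetical sizes per the docstring above; requires torch>=2.0 for the SDPA-based processor.
    block = AllegroTransformerBlock(
        dim=2304,
        num_attention_heads=24,
        attention_head_dim=96,
        cross_attention_dim=2304,
        activation_fn="gelu-approximate",
        attention_bias=True,
    )
    print(sum(p.numel() for p in block.parameters()))  # rough parameter count of one block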
@@ -48,11 +70,8 @@ def __init__(
         cross_attention_dim: Optional[int] = None,
         activation_fn: str = "geglu",
         attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
         norm_eps: float = 1e-5,
-        final_dropout: bool = False,
     ):
         super().__init__()

@@ -65,8 +84,7 @@ def __init__(
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
+            cross_attention_dim=None,
             processor=AllegroAttnProcessor2_0(),
         )

@@ -79,9 +97,8 @@ def __init__(
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
-            upcast_attention=upcast_attention,
             processor=AllegroAttnProcessor2_0(),
-        )  # is self-attn if encoder_hidden_states is none
+        )

         # 3. Feed Forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
@@ -90,7 +107,6 @@ def __init__(
             dim,
             dropout=dropout,
             activation_fn=activation_fn,
-            final_dropout=final_dropout,
         )

         # 4. Scale-shift
@@ -159,37 +175,55 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
     _supports_gradient_checkpointing = True

     """
-    A 2D Transformer model for image-like data.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            The number of channels in the input and output (specify if the input is **continuous**).
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
-            This is fixed during training since it is used to learn a number of position embeddings.
-        num_vector_embeds (`int`, *optional*):
-            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
-            Includes the class for the masked latent pixel.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
-        num_embeds_ada_norm ( `int`, *optional*):
-            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
-            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
-            added to the hidden states.
-
-            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the `TransformerBlocks` attention should contain a bias parameter.
+    A 3D Transformer model for video-like data.
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of temporal patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `96`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `4`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `4`):
+            The number of channels in the output.
+        num_layers (`int`, defaults to `32`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        attention_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention projection layers.
+        sample_height (`int`, defaults to `90`):
+            The height of the input latents.
+        sample_width (`int`, defaults to `160`):
+            The width of the input latents.
+        sample_frames (`int`, defaults to `22`):
+            The number of frames in the input latents.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether or not to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        caption_channels (`int`, defaults to `4096`):
+            Number of channels to use for projecting the caption embeddings.
+        interpolation_scale_h (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across height dimension.
+        interpolation_scale_w (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across width dimension.
+        interpolation_scale_t (`float`, defaults to `2.2`):
+            Scaling factor to apply in 3D positional embeddings across time dimension.
     """

     @register_to_config
     def __init__(
         self,
         patch_size: int = 2,
-        patch_size_temporal: int = 1,
+        patch_size_t: int = 1,
         num_attention_heads: int = 24,
         attention_head_dim: int = 96,
         in_channels: int = 4,
@@ -202,7 +236,6 @@ def __init__(
         sample_width: int = 160,
         sample_frames: int = 22,
         activation_fn: str = "gelu-approximate",
-        upcast_attention: bool = False,
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         caption_channels: int = 4096,
@@ -245,7 +278,6 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     activation_fn=activation_fn,
                     attention_bias=attention_bias,
-                    upcast_attention=upcast_attention,
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                 )
@@ -259,7 +291,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

         # 4. Timestep embeddings
-        self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

         # 5. Caption projection
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
@@ -280,9 +312,13 @@
         return_dict: bool = True,
     ):
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p_t = self.config.patch_size_temporal
+        p_t = self.config.patch_size_t
         p = self.config.patch_size

+        post_patch_num_frames = num_frames // self.config.patch_size_temporal
+        post_patch_height = height // self.config.patch_size
+        post_patch_width = width // self.config.patch_size
+
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -317,22 +353,20 @@
         encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
         encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

-        # 1. Input
-        post_patch_num_frames = num_frames // self.config.patch_size_temporal
-        post_patch_height = height // self.config.patch_size
-        post_patch_width = width // self.config.patch_size
-
+        # 1. Timestep embeddings
         timestep, embedded_timestep = self.adaln_single(
             timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
         )

+        # 2. Patch embeddings
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.pos_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
             if self.gradient_checkpointing:
@@ -364,7 +398,7 @@ def custom_forward(*inputs):
                     image_rotary_emb=image_rotary_emb,
                 )

-        # 3. Output
+        # 4. Output normalization & projection
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)

@@ -373,7 +407,7 @@ def custom_forward(*inputs):
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.squeeze(1)

-        # unpatchify
+        # 5. Unpatchify
         hidden_states = hidden_states.reshape(
             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
         )
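
The renamed patch-size config values drive the shape bookkeeping around patchify/unpatchify. A worked example of the arithmetic using the defaults documented in this file (patch_size=2, patch_size_t=1, sample_height=90, sample_width=160, sample_frames=22, out_channels=4); this sketch does not call into the model:

    batch_size, out_channels = 2, 4
    num_frames, height, width = 22, 90, 160
    p, p_t = 2, 1  # patch_size, patch_size_t

    post_patch_num_frames = num_frames // p_t  # 22
    post_patch_height = height // p            # 45
    post_patch_width = width // p              # 80

    # proj_out emits patch_size * patch_size * out_channels = 16 features per token, so the
    # reshape above followed by a permute recovers a latent video of the original size.
    num_tokens = post_patch_num_frames * post_patch_height * post_patch_width  # 79200
    output_shape = (batch_size, out_channels, num_frames, height, width)       # (2, 4, 22, 90, 160)
    print(num_tokens, output_shape)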

0 commit comments
