Commit 8f9ffa8

refactor part 5
1 parent bcba858 commit 8f9ffa8

3 files changed, +20 -100 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 7 additions & 8 deletions
@@ -18,7 +18,6 @@
 
 import torch
 import torch.nn as nn
-from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ..attention_processor import Attention, SpatialNorm
@@ -114,7 +113,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv1(hidden_states)
 
         if self.up_sample:
-            hidden_states = rearrange(hidden_states, "b (d c) f h w -> b c (f d) h w", d=2)
+            hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
 
         hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
         hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
@@ -858,10 +857,10 @@ def encode(
                     )
                     out_video_cube[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean_blend
 
-        ## final conv
-        out_video_cube = rearrange(out_video_cube, "b c n h w -> (b n) c h w")
+        # final conv
+        out_video_cube = out_video_cube.permute(0, 2, 1, 3, 4).flatten(0, 1)
         out_video_cube = self.quant_conv(out_video_cube)
-        out_video_cube = rearrange(out_video_cube, "(b n) c h w -> b c n h w", b=B)
+        out_video_cube = out_video_cube.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4)
 
         posterior = DiagonalGaussianDistribution(out_video_cube)
 
@@ -885,9 +884,9 @@ def decode(
         B, C, N, H, W = input_latents.shape
 
         ## post quant conv (a mapping)
-        input_latents = rearrange(input_latents, "b c n h w -> (b n) c h w")
+        input_latents = input_latents.permute(0, 2, 1, 3, 4).flatten(0, 1)
         input_latents = self.post_quant_conv(input_latents)
-        input_latents = rearrange(input_latents, "(b n) c h w -> b c n h w", b=B)
+        input_latents = input_latents.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4)
 
         ## out tensor shape
         out_n = math.floor((N - IN_KERNEL[0]) / IN_STRIDE[0]) + 1
@@ -947,7 +946,7 @@ def decode(
                     )
                     out_video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
 
-        out_video = rearrange(out_video, "b c t h w -> b t c h w").contiguous()
+        out_video = out_video.permute(0, 2, 1, 3, 4).contiguous()
 
         decoded = out_video
         if not return_dict:
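For reference, a minimal sketch (not part of the commit) that checks the pure-torch replacements above against the removed einops calls; it assumes einops is still installed locally for the comparison, even though the refactored module no longer imports it.

# Quick equivalence check (not part of the commit) for the rearrange -> permute/flatten swaps above.
# Assumes einops is available locally for comparison only.
import torch
from einops import rearrange

x = torch.randn(2, 8, 3, 4, 5)  # b, (d c) with d=2, f, h, w
assert torch.equal(
    rearrange(x, "b (d c) f h w -> b c (f d) h w", d=2),
    x.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3),
)

y = torch.randn(2, 4, 6, 8, 8)  # b, c, n, h, w
B = y.shape[0]
folded = y.permute(0, 2, 1, 3, 4).flatten(0, 1)  # "b c n h w -> (b n) c h w"
assert torch.equal(rearrange(y, "b c n h w -> (b n) c h w"), folded)
assert torch.equal(
    rearrange(folded, "(b n) c h w -> b c n h w", b=B),
    folded.unflatten(0, (B, -1)).permute(0, 2, 1, 3, 4),
)

The tensor layouts are identical; the change only drops the einops runtime dependency.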

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 11 additions & 70 deletions
@@ -18,17 +18,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import (
-    AllegroAttnProcessor2_0,
-    Attention,
-)
-from ..embeddings import PixArtAlphaTextProjection
+from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AllegroAdaLayerNormSingle
@@ -37,57 +33,6 @@
 logger = logging.get_logger(__name__)
 
 
-class PatchEmbed2D(nn.Module):
-    """2D Image to Patch Embedding"""
-
-    def __init__(
-        self,
-        num_frames=1,
-        height=224,
-        width=224,
-        patch_size_t=1,
-        patch_size=16,
-        in_channels=3,
-        embed_dim=768,
-        layer_norm=False,
-        flatten=True,
-        bias=True,
-        use_abs_pos=False,
-    ):
-        super().__init__()
-        self.use_abs_pos = use_abs_pos
-        self.flatten = flatten
-        self.layer_norm = layer_norm
-
-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
-        )
-        if layer_norm:
-            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
-        else:
-            self.norm = None
-
-        self.patch_size_t = patch_size_t
-        self.patch_size = patch_size
-
-    def forward(self, latent):
-        b, _, _, _, _ = latent.shape
-        video_latent = None
-
-        latent = rearrange(latent, "b c t h w -> (b t) c h w")
-
-        latent = self.proj(latent)
-        if self.flatten:
-            latent = latent.flatten(2).transpose(1, 2)  # BT C H W -> BT N C
-        if self.layer_norm:
-            latent = self.norm(latent)
-
-        latent = rearrange(latent, "(b t) n c -> b (t n) c", b=b)
-        video_latent = latent
-
-        return video_latent
-
-
 @maybe_allow_in_graph
 class AllegroTransformerBlock(nn.Module):
     r"""
@@ -280,13 +225,13 @@ def __init__(
         interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
 
         # 1. Patch embedding
-        self.pos_embed = PatchEmbed2D(
+        self.pos_embed = PatchEmbed(
             height=sample_height,
             width=sample_width,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=self.inner_dim,
-            # pos_embed_type=None,
+            pos_embed_type=None,
         )
 
         # 2. Transformer blocks
@@ -327,8 +272,8 @@ def _set_gradient_checkpointing(self, module, value=False):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
@@ -368,13 +313,9 @@ def forward(
            )
 
        # convert encoder_attention_mask to a bias the same way we do for attention_mask
-        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
-            # b, 1+use_image_num, l -> a video with images
-            # b, 1, l -> only images
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
-            encoder_attention_mask = (
-                rearrange(encoder_attention_mask, "b 1 l -> (b 1) 1 l") if encoder_attention_mask.numel() > 0 else None
-            )
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
         # 1. Input
         post_patch_num_frames = num_frames // self.config.patch_size_temporal
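Illustrative sketch (not part of the commit) of the mask-to-bias conversion in the hunk above; an explicit float cast stands in for `.to(self.dtype)` so the example is self-contained.

# Illustrative only (not part of the commit): converting a 2D padding mask (1 = keep, 0 = drop)
# into an additive attention bias with a singleton query dimension, as the hunk above now does.
import torch

encoder_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # (batch, seq_len)
bias = (1 - encoder_attention_mask.float()) * -10000.0    # 0 where kept, -10000 where masked
bias = bias.unsqueeze(1)                                  # (batch, 1, seq_len), broadcasts over query tokens
print(bias.shape)  # torch.Size([1, 1, 5])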
@@ -385,9 +326,9 @@ def forward(
             timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
         )
 
-        hidden_states = self.pos_embed(
-            hidden_states
-        )  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+        hidden_states = self.pos_embed(hidden_states)
+        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
 
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
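A rough sketch (not part of the commit) of the inlined per-frame patch-embedding data flow that replaces the removed PatchEmbed2D; a plain nn.Conv2d stands in for the projection inside diffusers' PatchEmbed, and the shapes are made up for illustration.

# Rough sketch (not part of the commit): the inlined per-frame patch-embedding path.
# A plain Conv2d stands in for the projection inside PatchEmbed (used with pos_embed_type=None).
import torch
import torch.nn as nn

batch_size, channels, num_frames, height, width = 1, 4, 8, 64, 64
patch_size, embed_dim = 2, 96
proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

hidden_states = torch.randn(batch_size, channels, num_frames, height, width)

# b c t h w -> (b t) c h w, as the removed rearrange did
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
# per-frame 2D patchify: (b t) c h w -> (b t) num_patches embed_dim
hidden_states = proj(hidden_states).flatten(2).transpose(1, 2)
# (b t) n d -> b (t n) d, as the removed "(b t) n c -> b (t n) c" rearrange did
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

assert hidden_states.shape == (batch_size, num_frames * (height // patch_size) * (width // patch_size), embed_dim)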

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 2 additions & 22 deletions
@@ -836,25 +836,11 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                current_timestep = t
-                if not torch.is_tensor(current_timestep):
-                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-                    # This would be a good case for the `match` statement (Python 3.10+)
-                    is_mps = latent_model_input.device.type == "mps"
-                    if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
-                    else:
-                        dtype = torch.int32 if is_mps else torch.int64
-                    current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
-                elif len(current_timestep.shape) == 0:
-                    current_timestep = current_timestep[None].to(latent_model_input.device)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                current_timestep = current_timestep.expand(latent_model_input.shape[0])
+                timestep = t.expand(latent_model_input.shape[0])
 
                 if prompt_embeds.ndim == 3:
                     prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-                if prompt_attention_mask.ndim == 2:
-                    prompt_attention_mask = prompt_attention_mask.unsqueeze(1)  # b l -> b 1 l
 
                 # prepare attention_mask.
                 # b c t h w -> b t h w
@@ -866,7 +852,7 @@ def __call__(
                     attention_mask=attention_mask,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
-                    timestep=current_timestep,
+                    timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]
@@ -876,12 +862,6 @@ def __call__(
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # learned sigma
-                if latent_channels == self.transformer.config.out_channels // 2:
-                    noise_pred = noise_pred.chunk(2, dim=1)[0]
-                else:
-                    noise_pred = noise_pred
-
                 # compute previous image: x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
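A small sketch (not part of the commit) of the timestep broadcasting idiom that replaces the removed CPU/GPU conversion block; it assumes `t` is an element of `scheduler.timesteps` and therefore already a 0-dim tensor.

# Small sketch (not part of the commit): broadcasting a scalar timestep tensor to the batch,
# as the simplified loop body above now does. Values are made up for illustration.
import torch

latent_model_input = torch.randn(2, 4, 8, 16, 16)  # e.g. after the classifier-free guidance cat
t = torch.tensor(981)  # one element of scheduler.timesteps, already a 0-dim tensor

timestep = t.expand(latent_model_input.shape[0])
print(timestep)  # tensor([981, 981])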
