Commit ec05bbd: refactor part 2
Parent: 901d10e

2 files changed (+49, -107 lines)


src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 47 additions & 102 deletions
@@ -314,21 +314,20 @@ def __init__(
         interpolation_scale_h: float = 2.0,
         interpolation_scale_w: float = 2.0,
         interpolation_scale_t: float = 2.2,
-        use_additional_conditions: Optional[bool] = None,
         use_rotary_positional_embeddings: bool = True,
         model_max_length: int = 300,
     ):
         super().__init__()

         self.inner_dim = num_attention_heads * attention_head_dim
-        self.out_channels = in_channels if out_channels is None else out_channels

         interpolation_scale_t = (
             interpolation_scale_t if interpolation_scale_t is not None else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16
         )
         interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
         interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40

+        # 1. Patch embedding
         self.pos_embed = PatchEmbed2D(
             height=sample_height,
             width=sample_width,
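
A side note on the `interpolation_scale_t` default shown in the context above: for an even `sample_frames` it works out to floor(sample_frames / 16), and for an odd count to ceil(sample_frames / 16). A tiny arithmetic check (arbitrary frame counts, not taken from the commit):

# Quick check of the interpolation_scale_t default (arbitrary example values).
sample_frames = 88  # even -> 88 // 16 = 5
print(((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16)  # 5

sample_frames = 89  # odd -> (89 - 1) // 16 + 1 = 6
print(((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16)  # 6
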
@@ -337,9 +336,8 @@ def __init__(
             embed_dim=self.inner_dim,
             # pos_embed_type=None,
         )
-        interpolation_scale_thw = (interpolation_scale_t, interpolation_scale_h, interpolation_scale_w)

-        # 3. Define transformers blocks, spatial attention
+        # 2. Transformer blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 AllegroTransformerBlock(
@@ -358,19 +356,18 @@ def __init__(
             ]
         )

-        # 4. Define output layers
+        # 3. Output projection & norm
         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
-        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

-        # 5. PixArt-Alpha blocks.
+        # 4. Timestep embeddings
         self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

-        self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(
-                in_features=caption_channels, hidden_size=self.inner_dim
-            )
+        # 5. Caption projection
+        self.caption_projection = PixArtAlphaTextProjection(
+            in_features=caption_channels, hidden_size=self.inner_dim
+        )

         self.gradient_checkpointing = False

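Since the refactor constructs the caption projection unconditionally, `caption_channels` is assumed to always be provided. As a standalone illustration (the sizes below are hypothetical, not taken from this commit), the projection only maps the last dimension of the text-encoder states from `caption_channels` to the transformer's inner width:

# Illustrative sketch only; caption_channels / inner_dim values are hypothetical.
import torch
from diffusers.models.embeddings import PixArtAlphaTextProjection

caption_channels, inner_dim = 4096, 2304
proj = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

# [batch, 1, seq_len, caption_channels] -> [batch, 1, seq_len, inner_dim]
encoder_hidden_states = torch.randn(2, 1, 300, caption_channels)
print(proj(encoder_hidden_states).shape)  # torch.Size([2, 1, 300, 2304])
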
@@ -382,15 +379,14 @@ def forward(
         hidden_states: torch.Tensor,
         timestep: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        added_cond_kwargs: Dict[str, torch.Tensor] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ):
-        batch_size, c, frame, h, w = hidden_states.shape
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t = self.config.patch_size_temporal
+        p = self.config.patch_size

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
@@ -409,112 +405,61 @@ def forward(
             # (keep = +0, discard = -10000.0)
             # b, frame+use_image_num, h, w -> a video with images
             # b, 1, h, w -> only images
-            attention_mask = attention_mask.to(self.dtype)
-            attention_mask_vid = attention_mask[:, :frame]  # b, frame, h, w
+            attention_mask = attention_mask.to(hidden_states.dtype)
+            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]

-            if attention_mask_vid.numel() > 0:
-                attention_mask_vid = attention_mask_vid.unsqueeze(1)  # b 1 t h w
-                attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size))
-                attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)')
+            if attention_mask.numel() > 0:
+                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
+                attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
+                attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)

-            attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None
+            attention_mask = (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
             # b, 1+use_image_num, l -> a video with images
             # b, 1, l -> only images
             encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
-            encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None
+            encoder_attention_mask = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None

         # 1. Input
-        frame = frame // self.config.patch_size_temporal
-        height = hidden_states.shape[-2] // self.config.patch_size
-        width = hidden_states.shape[-1] // self.config.patch_size
+        post_patch_num_frames = num_frames // self.config.patch_size_temporal
+        post_patch_height = height // self.config.patch_size
+        post_patch_width = width // self.config.patch_size

-        added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs
-        hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs(
-            hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size,
-        )
+        timestep, embedded_timestep = self.adaln_single(timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype)
+
+        hidden_states = self.pos_embed(hidden_states)  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

-        for _, block in enumerate(self.transformer_blocks):
+        for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            block: AllegroTransformerBlock
-            hidden_states = block.forward(
+            hidden_states = block(
                 hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states_vid,
-                temb=timestep_vid,
-                attention_mask=attention_mask_vid,
-                encoder_attention_mask=encoder_attention_mask_vid,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=timestep,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                 image_rotary_emb=image_rotary_emb,
             )

         # 3. Output
-        output = None
-        if hidden_states is not None:
-            output = self._get_output_for_patched_inputs(
-                hidden_states=hidden_states,
-                timestep=timestep_vid,
-                class_labels=class_labels,
-                embedded_timestep=embedded_timestep_vid,
-                num_frames=frame,
-                height=height,
-                width=width,
-            )  # b c t h w
+        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1)
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

         if not return_dict:
             return (output,)

         return Transformer2DModelOutput(sample=output)
-
-    def _operate_on_patched_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, added_cond_kwargs: Dict[str, Any], batch_size: int):
-        hidden_states = self.pos_embed(hidden_states.to(self.dtype))  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
-
-        timestep_vid = None
-        embedded_timestep_vid = None
-        encoder_hidden_states_vid = None
-
-        if self.adaln_single is not None:
-            if self.config.use_additional_conditions and added_cond_kwargs is None:
-                raise ValueError(
-                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                )
-            timestep, embedded_timestep = self.adaln_single(
-                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
-            )  # b 6d, b d
-
-            timestep_vid = timestep
-            embedded_timestep_vid = embedded_timestep
-
-        if self.caption_projection is not None:
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # b, 1+use_image_num, l, d or b, 1, l, d
-            encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d')
-
-        return hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid
-
-    def _get_output_for_patched_inputs(
-        self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None
-    ) -> torch.Tensor:
-        if self.config.norm_type != "ada_norm_single":
-            conditioning = self.transformer_blocks[0].norm1.emb(
-                timestep, class_labels, hidden_dtype=self.dtype
-            )
-            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-            hidden_states = self.proj_out_2(hidden_states)
-        elif self.config.norm_type == "ada_norm_single":
-            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states)
-            # Modulation
-            hidden_states = hidden_states * (1 + scale) + shift
-            hidden_states = self.proj_out(hidden_states)
-            hidden_states = hidden_states.squeeze(1)
-
-        # unpatchify
-        if self.adaln_single is None:
-            height = width = int(hidden_states.shape[1] ** 0.5)
-        hidden_states = hidden_states.reshape(
-            shape=(-1, num_frames, height, width, self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size, self.out_channels)
-        )
-        hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
-        output = hidden_states.reshape(-1, self.out_channels, num_frames * self.config.patch_size_temporal, height * self.config.patch_size, width * self.config.patch_size)
-        return output
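
Two standalone sanity checks of the refactored forward pass follow; both use dummy tensors with made-up sizes and are not part of the commit. The first mirrors the new attention-mask handling: max-pooling the keep/discard mask over (p_t, p, p) patches leaves one value per latent patch token, and the flatten/view produces the [batch, 1, num_patch_tokens] bias layout expected by the attention blocks.

# Sanity-check sketch (hypothetical sizes), mirroring the mask pooling above.
import torch
import torch.nn.functional as F

batch_size, num_frames, height, width = 2, 8, 16, 16
p_t, p = 2, 2  # hypothetical temporal / spatial patch sizes

attention_mask = torch.ones(batch_size, num_frames, height, width)  # 1 = keep, 0 = discard
attention_mask[:, :, height // 2 :, :] = 0  # pretend the bottom half of every frame is padding

attention_mask = attention_mask.unsqueeze(1)  # [B, 1, T, H, W]
attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)  # [B, 1, (T/p_t) * (H/p) * (W/p)]
attention_mask = (1 - attention_mask.bool().to(torch.float32)) * -10000.0  # keep = 0, discard = -10000

print(attention_mask.shape)  # torch.Size([2, 1, 256])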

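The second sketch isolates the inlined unpatchify at the end of `forward`. It assumes `patch_size_temporal == 1`, which is consistent with `proj_out` being sized `patch_size * patch_size * out_channels` in this commit; the reshape/permute folds the patch axes back into the frame grid and recovers the [batch, channels, frames, height, width] video layout.

# Shape-only sketch of the unpatchify step (dummy tensor, hypothetical sizes).
import torch

batch_size, out_channels = 2, 4
num_frames, height, width = 8, 16, 16
p_t, p = 1, 2  # assumes patch_size_temporal == 1, matching proj_out's width
post_patch_num_frames = num_frames // p_t  # 8
post_patch_height = height // p            # 8
post_patch_width = width // p              # 8

# proj_out output: one (p_t * p * p * out_channels)-sized vector per patch token
num_tokens = post_patch_num_frames * post_patch_height * post_patch_width
hidden_states = torch.randn(batch_size, num_tokens, p_t * p * p * out_channels)

hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)  # [B, C, T/p_t, p_t, H/p, p, W/p, p]
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
print(output.shape)  # torch.Size([2, 4, 8, 16, 16])
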
src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 2 additions & 5 deletions
@@ -808,12 +808,10 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 6.1 Prepare micro-conditions.
-        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-
+        # 7. Prepare rotary embeddings
         image_rotary_emb = self._prepare_rotary_positional_embeddings(batch_size, height, width, latents.size(2), device)

-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

         progress_wrap = tqdm.tqdm if verbose else (lambda x: x)
@@ -853,7 +851,6 @@ def __call__(
                 encoder_hidden_states=prompt_embeds,
                 encoder_attention_mask=prompt_attention_mask,
                 timestep=current_timestep,
-                added_cond_kwargs=added_cond_kwargs,
                 image_rotary_emb=image_rotary_emb,
                 return_dict=False,
             )[0]
