
Commit c76dc5a

refactor part 6
1 parent 8f9ffa8 commit c76dc5a

File tree

2 files changed: 92 additions & 71 deletions


src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 91 additions & 70 deletions
@@ -18,12 +18,12 @@
 import math
 import re
 import urllib.parse as ul
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-import tqdm
 from transformers import T5EncoderModel, T5Tokenizer

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
 from ...models.embeddings import get_3d_rotary_pos_embed_allegro
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -171,6 +171,12 @@ class AllegroPipeline(DiffusionPipeline):
     _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
     model_cpu_offload_seq = "text_encoder->transformer->vae"

+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
     def __init__(
         self,
         tokenizer: T5Tokenizer,
@@ -198,7 +204,7 @@ def encode_prompt(
         prompt: Union[str, List[str]],
         do_classifier_free_guidance: bool = True,
         negative_prompt: str = "",
-        num_images_per_prompt: int = 1,
+        num_videos_per_prompt: int = 1,
         device: Optional[torch.device] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -286,10 +292,10 @@ def encode_prompt(

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
         prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
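As a side note on the repeat/view pattern above (unchanged by this commit apart from the rename), it turns copies made along the sequence axis into extra batch entries. A minimal standalone sketch with arbitrary illustrative shapes:

```python
import torch

# Arbitrary illustrative shapes: 2 prompts, sequence length 300, hidden size 4096.
bs_embed, seq_len, dim = 2, 300, 4096
num_videos_per_prompt = 3

prompt_embeds = torch.randn(bs_embed, seq_len, dim)

# Repeat along the sequence axis, then fold the copies back into the batch axis.
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)                  # (2, 900, 4096)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)  # (6, 300, 4096)
```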
@@ -320,11 +326,11 @@ def encode_prompt(

             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

             negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
@@ -355,8 +361,8 @@ def check_inputs(
         num_frames,
         height,
         width,
-        negative_prompt,
-        callback_steps,
+        callback_on_step_end_tensor_inputs,
+        negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
@@ -367,12 +373,11 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )

         if prompt is not None and prompt_embeds is not None:
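For context, the new validation rejects requests to track tensors the pipeline does not expose before any denoising starts. A hypothetical illustration, assuming a loaded `AllegroPipeline` instance named `pipe` (prompt and tensor name are made up):

```python
# "noise_pred" is not in AllegroPipeline._callback_tensor_inputs
# ("latents", "prompt_embeds", "negative_prompt_embeds"), so check_inputs()
# raises a ValueError naming the offending key.
pipe(
    prompt="A sailboat at sunset",
    callback_on_step_end=lambda pipeline, i, t, kwargs: kwargs,
    callback_on_step_end_tensor_inputs=["noise_pred"],
)
```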
@@ -606,20 +611,16 @@ def _prepare_rotary_positional_embeddings(
         num_frames: int,
         device: torch.device,
     ):
-        attention_head_dim = 96
-        vae_scale_factor_spatial = 8
-        patch_size = 2
-
-        grid_height = height // (vae_scale_factor_spatial * patch_size)
-        grid_width = width // (vae_scale_factor_spatial * patch_size)
-        base_size_width = 1280 // (vae_scale_factor_spatial * patch_size)
-        base_size_height = 720 // (vae_scale_factor_spatial * patch_size)
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 1280 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

         grid_crops_coords = get_resize_crop_region_for_grid(
             (grid_height, grid_width), base_size_width, base_size_height
         )
         freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
-            embed_dim=attention_head_dim,
+            embed_dim=self.transformer.config.attention_head_dim,
             crops_coords=grid_crops_coords,
             grid_size=(grid_height, grid_width),
             temporal_size=num_frames,
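For reference, with the values that were previously hard-coded here (`vae_scale_factor_spatial = 8`, `patch_size = 2`) the grid arithmetic for the 720x1280 base resolution works out as below; the refactored code reads the same quantities from the VAE and transformer configs, so the numbers only change if the model config does:

```python
# Values that were hard-coded before this commit; now taken from the model configs.
vae_scale_factor_spatial = 8
patch_size = 2

height, width = 720, 1280
grid_height = height // (vae_scale_factor_spatial * patch_size)    # 720 // 16 = 45
grid_width = width // (vae_scale_factor_spatial * patch_size)      # 1280 // 16 = 80
base_size_height = 720 // (vae_scale_factor_spatial * patch_size)  # 45
base_size_width = 1280 // (vae_scale_factor_spatial * patch_size)  # 80
```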
@@ -653,10 +654,10 @@ def __call__(
         num_inference_steps: int = 100,
         timesteps: List[int] = None,
         guidance_scale: float = 7.5,
-        num_images_per_prompt: Optional[int] = 1,
         num_frames: Optional[int] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -666,11 +667,12 @@ def __call__(
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         clean_caption: bool = True,
         max_sequence_length: int = 300,
-        verbose: bool = True,
     ) -> Union[AllegroPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
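A hedged usage sketch of the refactored `__call__` signature. The checkpoint id `rhymes-ai/Allegro` and the `.frames` attribute on the output are assumptions not taken from this diff, and the callback body is purely illustrative:

```python
import torch
from diffusers import AllegroPipeline

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
pipe.to("cuda")

def log_step(pipeline, step, timestep, callback_kwargs):
    # Tensors named in callback_on_step_end_tensor_inputs arrive in callback_kwargs.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent std {latents.std().item():.4f}")
    return callback_kwargs  # returned entries are written back into the denoising loop

output = pipe(
    prompt="A sailboat gliding across a calm sea at sunset",
    num_inference_steps=100,
    guidance_scale=7.5,
    num_videos_per_prompt=1,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
)
video = output.frames[0]  # assumed output attribute for the generated video
```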
@@ -746,6 +748,12 @@ def __call__(
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
         """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        num_videos_per_prompt = 1
+
         # 1. Check inputs. Raise error if not correct
         num_frames = num_frames or self.transformer.config.sample_size_t * self.vae_scale_factor_temporal
         height = height or self.transformer.config.sample_size[0] * self.vae_scale_factor_spatial
@@ -756,13 +764,15 @@ def __call__(
             num_frames,
             height,
             width,
+            callback_on_step_end_tensor_inputs,
             negative_prompt,
-            callback_steps,
             prompt_embeds,
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
         )
+        self._guidance_scale = guidance_scale
+        self._interrupt = False

         # 2. Default height and width to transformer
         if prompt is not None and isinstance(prompt, str):
@@ -789,7 +799,7 @@ def __call__(
             prompt,
             do_classifier_free_guidance,
             negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
+            num_videos_per_prompt=num_videos_per_prompt,
             device=device,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -809,7 +819,7 @@ def __call__(
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
-            batch_size * num_images_per_prompt,
+            batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
             height,
@@ -831,45 +841,56 @@ def __call__(
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        progress_wrap = tqdm.tqdm if verbose else (lambda x: x)
-        for i, t in progress_wrap(list(enumerate(timesteps))):
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            timestep = t.expand(latent_model_input.shape[0])
-
-            if prompt_embeds.ndim == 3:
-                prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-            # prepare attention_mask.
-            # b c t h w -> b t h w
-            attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
-            # predict noise model_output
-            noise_pred = self.transformer(
-                latent_model_input,
-                attention_mask=attention_mask,
-                encoder_hidden_states=prompt_embeds,
-                encoder_attention_mask=prompt_attention_mask,
-                timestep=timestep,
-                image_rotary_emb=image_rotary_emb,
-                return_dict=False,
-            )[0]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute previous image: x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                if callback is not None and i % callback_steps == 0:
-                    step_idx = i // getattr(self.scheduler, "order", 1)
-                    callback(step_idx, t, latents)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                if prompt_embeds.ndim == 3:
+                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
+
+                # prepare attention_mask.
+                # b c t h w -> b t h w
+                attention_mask = torch.ones_like(latent_model_input)[:, 0]
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    latent_model_input,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute previous image: x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()

         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
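Because the loop now checks `self.interrupt` at the top of every step and folds callback outputs back into its state, a step-end callback can stop generation early or swap the tracked tensors. A sketch reusing the `pipe` instance from the earlier example, under the assumption that the pipeline exposes `interrupt` as a property backed by the `_interrupt` flag set above (the cutoff step is arbitrary):

```python
def stop_early(pipeline, step, timestep, callback_kwargs):
    # Illustrative only: flip the interrupt flag after 50 steps so the loop
    # skips the remaining iterations via `if self.interrupt: continue`.
    if step >= 50:
        pipeline._interrupt = True
    return callback_kwargs

output = pipe(
    prompt="A timelapse of clouds over a mountain range",
    callback_on_step_end=stop_early,
    callback_on_step_end_tensor_inputs=["latents"],
)
```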

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 1 addition & 1 deletion
@@ -649,8 +649,8 @@ def __call__(
             height,
             width,
             prompt_embeds.dtype,
-            device,
             generator,
+            device,
             latents,
         )
