Skip to content

Commit 900fead

Browse files
committed
update
1 parent 2cfca5e commit 900fead

File tree

2 files changed

+80
-112
lines changed

2 files changed

+80
-112
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3554,11 +3554,11 @@ def __call__(
35543554
if image_rotary_emb is not None:
35553555

35563556
def apply_rotary_emb(x, freqs_cos, freqs_sin):
3557-
x_even = x[..., 0::2]
3558-
x_odd = x[..., 1::2]
3557+
x_even = x[..., 0::2].float()
3558+
x_odd = x[..., 1::2].float()
35593559

3560-
cos = (x_even * freqs_cos.float() - x_odd * freqs_sin.float()).to(x.dtype)
3561-
sin = (x_even * freqs_sin.float() + x_odd * freqs_cos.float()).to(x.dtype)
3560+
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
3561+
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
35623562

35633563
return torch.stack([cos, sin], dim=-1).flatten(-2)
35643564

@@ -3572,40 +3572,23 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
35723572
encoder_value.transpose(1, 2),
35733573
)
35743574

3575-
batch_size, heads, sequence_length, dim = query.shape
3576-
encoder_sequence_length = encoder_query.shape[2]
3577-
total_length = sequence_length + encoder_sequence_length
3575+
sequence_length = query.size(2)
3576+
encoder_sequence_length = encoder_query.size(2)
35783577

35793578
query = torch.cat([query, encoder_query], dim=2)
35803579
key = torch.cat([key, encoder_key], dim=2)
35813580
value = torch.cat([value, encoder_value], dim=2)
35823581

35833582
# Zero out tokens based on the attention mask
3584-
# query = query * attention_mask[:, None, :, None]
3585-
# key = key * attention_mask[:, None, :, None]
3586-
# value = value * attention_mask[:, None, :, None]
3583+
query = query * attention_mask[:, None, :, None]
3584+
key = key * attention_mask[:, None, :, None]
3585+
value = value * attention_mask[:, None, :, None]
35873586

3588-
query = query.view(1, query.size(1), -1, query.size(-1))
3589-
key = key.view(1, key.size(1), -1, key.size(-1))
3590-
value = value.view(1, value.size(1), -1, key.size(-1))
3591-
3592-
select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
3593-
3594-
query = torch.index_select(query, 2, select_index)
3595-
key = torch.index_select(key, 2, select_index)
3596-
value = torch.index_select(value, 2, select_index)
3597-
3598-
from torch.nn.attention import SDPBackend, sdpa_kernel
3599-
3600-
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
3601-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
3587+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
36023588

3603-
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0)
3604-
output = torch.zeros(
3605-
batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype
3606-
)
3607-
output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states)
3608-
hidden_states = output.view(batch_size, total_length, dim * heads)
3589+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
3590+
# Zero out tokens based on attention mask
3591+
hidden_states = hidden_states * attention_mask[:, :, None]
36093592

36103593
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
36113594
(sequence_length, encoder_sequence_length), dim=1

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 67 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from transformers import T5EncoderModel, T5TokenizerFast
2222

2323
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24-
from ...models.autoencoders import AutoencoderKLMochi
24+
from ...models.autoencoders import AutoencoderKL
2525
from ...models.transformers import MochiTransformer3DModel
2626
from ...schedulers import FlowMatchEulerDiscreteScheduler
2727
from ...utils import (
@@ -56,7 +56,7 @@
5656
>>> pipe.enable_model_cpu_offload()
5757
>>> pipe.enable_vae_tiling()
5858
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
59-
>>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
59+
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
6060
>>> export_to_video(frames, "mochi.mp4")
6161
```
6262
"""
@@ -164,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
164164
Conditional Transformer architecture to denoise the encoded video latents.
165165
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
166166
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
167-
vae ([`AutoencoderKLMochi`]):
168-
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
167+
vae ([`AutoencoderKL`]):
168+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
169169
text_encoder ([`T5EncoderModel`]):
170170
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
171171
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
184184
def __init__(
185185
self,
186186
scheduler: FlowMatchEulerDiscreteScheduler,
187-
vae: AutoencoderKLMochi,
187+
vae: AutoencoderKL,
188188
text_encoder: T5EncoderModel,
189189
tokenizer: T5TokenizerFast,
190190
transformer: MochiTransformer3DModel,
@@ -198,11 +198,17 @@ def __init__(
198198
transformer=transformer,
199199
scheduler=scheduler,
200200
)
201-
202-
self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
203-
self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
204-
205-
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
201+
# TODO: determine these scaling factors from model parameters
202+
self.vae_spatial_scale_factor = 8
203+
self.vae_temporal_scale_factor = 6
204+
self.patch_size = 2
205+
206+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
207+
self.tokenizer_max_length = (
208+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
209+
)
210+
self.default_height = 480
211+
self.default_width = 848
206212

207213
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
208214
def _get_t5_prompt_embeds(
@@ -253,6 +259,14 @@ def _get_t5_prompt_embeds(
253259

254260
return prompt_embeds, prompt_attention_mask
255261

262+
def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
263+
batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
264+
num_latents = latent_frames * latent_height * latent_width
265+
num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
266+
mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
267+
268+
return mask
269+
256270
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
257271
def encode_prompt(
258272
self,
@@ -335,12 +349,7 @@ def encode_prompt(
335349
dtype=dtype,
336350
)
337351

338-
return (
339-
prompt_embeds,
340-
prompt_attention_mask,
341-
negative_prompt_embeds,
342-
negative_prompt_attention_mask,
343-
)
352+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
344353

345354
def check_inputs(
346355
self,
@@ -424,13 +433,6 @@ def disable_vae_tiling(self):
424433
"""
425434
self.vae.disable_tiling()
426435

427-
def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
428-
batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
429-
num_latents = latent_frames * latent_height * latent_width
430-
num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
431-
mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
432-
return mask
433-
434436
def prepare_latents(
435437
self,
436438
batch_size,
@@ -443,9 +445,9 @@ def prepare_latents(
443445
generator,
444446
latents=None,
445447
):
446-
height = height // self.vae_scale_factor_spatial
447-
width = width // self.vae_scale_factor_spatial
448-
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
448+
height = height // self.vae_spatial_scale_factor
449+
width = width // self.vae_spatial_scale_factor
450+
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
449451

450452
shape = (batch_size, num_channels_latents, num_frames, height, width)
451453

@@ -485,7 +487,7 @@ def __call__(
485487
height: Optional[int] = None,
486488
width: Optional[int] = None,
487489
num_frames: int = 19,
488-
num_inference_steps: int = 50,
490+
num_inference_steps: int = 64,
489491
timesteps: List[int] = None,
490492
guidance_scale: float = 4.5,
491493
num_videos_per_prompt: Optional[int] = 1,
@@ -508,13 +510,13 @@ def __call__(
508510
prompt (`str` or `List[str]`, *optional*):
509511
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
510512
instead.
511-
height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
513+
height (`int`, *optional*, defaults to `self.default_height`):
512514
The height in pixels of the generated image. This is set to 480 by default for the best results.
513-
width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
515+
width (`int`, *optional*, defaults to `self.default_width`):
514516
The width in pixels of the generated image. This is set to 848 by default for the best results.
515517
num_frames (`int`, defaults to `19`):
516518
The number of video frames to generate
517-
num_inference_steps (`int`, *optional*, defaults to `50`):
519+
num_inference_steps (`int`, *optional*, defaults to 50):
518520
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
519521
expense of slower inference.
520522
timesteps (`List[int]`, *optional*):
@@ -574,8 +576,8 @@ def __call__(
574576
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
575577
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
576578

577-
height = height or 480 # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
578-
width = width or 848 # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
579+
height = height or self.default_height
580+
width = width or self.default_width
579581

580582
# 1. Check inputs. Raise error if not correct
581583
self.check_inputs(
@@ -601,6 +603,7 @@ def __call__(
601603
batch_size = prompt_embeds.shape[0]
602604

603605
device = self._execution_device
606+
604607
# 3. Prepare text embeddings
605608
(
606609
prompt_embeds,
@@ -619,10 +622,6 @@ def __call__(
619622
max_sequence_length=max_sequence_length,
620623
device=device,
621624
)
622-
# if self.do_classifier_free_guidance:
623-
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
624-
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
625-
626625
# 4. Prepare latent variables
627626
num_channels_latents = self.transformer.config.in_channels
628627
latents = self.prepare_latents(
@@ -636,16 +635,20 @@ def __call__(
636635
generator,
637636
latents,
638637
)
638+
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
639+
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
640+
641+
if self.do_classifier_free_guidance:
642+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
643+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
644+
joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0)
639645

640646
# 5. Prepare timestep
641647
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
642648
threshold_noise = 0.025
643649
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
644650
sigmas = np.array(sigmas)
645651

646-
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
647-
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
648-
649652
timesteps, num_inference_steps = retrieve_timesteps(
650653
self.scheduler,
651654
num_inference_steps,
@@ -662,40 +665,28 @@ def __call__(
662665
if self.interrupt:
663666
continue
664667

665-
# latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
666-
# # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
667-
# timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
668-
669-
latent_model_input = latents
668+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
669+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
670670
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
671671

672-
noise_pred_text = self.transformer(
672+
noise_pred = self.transformer(
673673
hidden_states=latent_model_input,
674674
encoder_hidden_states=prompt_embeds,
675675
timestep=timestep,
676676
encoder_attention_mask=prompt_attention_mask,
677677
joint_attention_mask=joint_attention_mask,
678678
return_dict=False,
679679
)[0]
680+
# Mochi CFG + Sampling runs in FP32
681+
noise_pred = noise_pred.to(torch.float32)
680682

681683
if self.do_classifier_free_guidance:
682-
noise_pred_uncond = self.transformer(
683-
hidden_states=latent_model_input,
684-
encoder_hidden_states=negative_prompt_embeds,
685-
timestep=timestep,
686-
encoder_attention_mask=negative_prompt_attention_mask,
687-
joint_attention_mask=negative_joint_attention_mask,
688-
return_dict=False,
689-
)[0]
690-
noise_pred = noise_pred_uncond.float() + self.guidance_scale * (
691-
noise_pred_text.float() - noise_pred_uncond.float()
692-
)
693-
else:
694-
noise_pred = noise_pred_text
684+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
685+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
695686

696687
# compute the previous noisy sample x_t -> x_t-1
697688
latents_dtype = latents.dtype
698-
latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0]
689+
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
699690
latents = latents.to(latents_dtype)
700691

701692
if latents.dtype != latents_dtype:
@@ -718,33 +709,27 @@ def __call__(
718709

719710
if XLA_AVAILABLE:
720711
xm.mark_step()
712+
721713
if output_type == "latent":
722714
video = latents
723715
else:
724-
with torch.autocast("cuda", torch.float32):
725-
# unscale/denormalize the latents
726-
# denormalize with the mean and std if available and not None
727-
has_latents_mean = (
728-
hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
716+
# unscale/denormalize the latents
717+
# denormalize with the mean and std if available and not None
718+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
719+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
720+
if has_latents_mean and has_latents_std:
721+
latents_mean = (
722+
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
729723
)
730-
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
731-
if has_latents_mean and has_latents_std:
732-
latents_mean = (
733-
torch.tensor(self.vae.config.latents_mean)
734-
.view(1, 12, 1, 1, 1)
735-
.to(latents.device, latents.dtype)
736-
)
737-
latents_std = (
738-
torch.tensor(self.vae.config.latents_std)
739-
.view(1, 12, 1, 1, 1)
740-
.to(latents.device, latents.dtype)
741-
)
742-
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
743-
else:
744-
latents = latents / self.vae.config.scaling_factor
745-
746-
video = self.vae.decode(latents, return_dict=False)[0]
747-
video = self.video_processor.postprocess_video(video, output_type=output_type)
724+
latents_std = (
725+
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
726+
)
727+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
728+
else:
729+
latents = latents / self.vae.config.scaling_factor
730+
731+
video = self.vae.decode(latents, return_dict=False)[0]
732+
video = self.video_processor.postprocess_video(video, output_type=output_type)
748733

749734
# Offload all models
750735
self.maybe_free_model_hooks()

0 commit comments

Comments
 (0)