 from transformers import T5EncoderModel, T5TokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...models.autoencoders import AutoencoderKLMochi
+from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -56,7 +56,7 @@
         >>> pipe.enable_model_cpu_offload()
         >>> pipe.enable_vae_tiling()
         >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
-        >>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
+        >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
         >>> export_to_video(frames, "mochi.mp4")
         ```
 """
@@ -164,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
             Conditional Transformer architecture to denoise the encoded video latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-        vae ([`AutoencoderKLMochi`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
     def __init__(
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKLMochi,
+        vae: AutoencoderKL,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
@@ -198,11 +198,17 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-
-        self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
-        self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
-
-        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+        # TODO: determine these scaling factors from model parameters
+        self.vae_spatial_scale_factor = 8
+        self.vae_temporal_scale_factor = 6
+        self.patch_size = 2
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+        self.default_height = 480
+        self.default_width = 848

     # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
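The new `__init__` hard-codes the 8x spatial and 6x temporal compression factors behind a TODO, while the removed lines show where those values would come from. A small hypothetical helper sketching that derivation (not part of this commit; it assumes the VAE exposes `spatial_compression_ratio` and `temporal_compression_ratio`, as the removed code did):

```python
def infer_vae_scale_factors(vae):
    """Sketch: prefer the VAE's own compression ratios when it exposes them,
    falling back to the Mochi defaults of 8 (spatial) and 6 (temporal)."""
    spatial = getattr(vae, "spatial_compression_ratio", 8)
    temporal = getattr(vae, "temporal_compression_ratio", 6)
    return spatial, temporal
```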
@@ -253,6 +259,14 @@ def _get_t5_prompt_embeds(

         return prompt_embeds, prompt_attention_mask

+    def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
+        batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
+        num_latents = latent_frames * latent_height * latent_width
+        num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
+        mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
+
+        return mask
+
     # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
     def encode_prompt(
         self,
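The added `prepare_joint_attention_mask` prepends one all-True entry per visual token to the text attention mask, so the mask covers the concatenated visual-plus-text sequence the transformer attends over. A standalone sketch with illustrative sizes (the shapes below are assumptions chosen to match the 480x848, 19-frame defaults elsewhere in this diff):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: one prompt of 256 text tokens and latents of shape
# (batch, channels, frames, height, width) = (1, 12, 4, 60, 106), patch size 2.
prompt_attention_mask = torch.ones(1, 256, dtype=torch.bool)
latents = torch.randn(1, 12, 4, 60, 106)
patch_size = 2

num_latents = latents.shape[2] * latents.shape[3] * latents.shape[4]  # 4 * 60 * 106 = 25440
num_visual_tokens = num_latents // patch_size**2                      # 25440 // 4 = 6360
# Left-pad the text mask with True so every visual token position is kept.
joint_mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
print(joint_mask.shape)  # torch.Size([1, 6616]) -> 6360 visual + 256 text positions
```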
@@ -335,12 +349,7 @@ def encode_prompt(
             dtype=dtype,
         )

-        return (
-            prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
-        )
+        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

     def check_inputs(
         self,
@@ -424,13 +433,6 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
-        batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
-        num_latents = latent_frames * latent_height * latent_width
-        num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
-        mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
-        return mask
-
     def prepare_latents(
         self,
         batch_size,
@@ -443,9 +445,9 @@ def prepare_latents(
         generator,
         latents=None,
     ):
-        height = height // self.vae_scale_factor_spatial
-        width = width // self.vae_scale_factor_spatial
-        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        height = height // self.vae_spatial_scale_factor
+        width = width // self.vae_spatial_scale_factor
+        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1

         shape = (batch_size, num_channels_latents, num_frames, height, width)

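For the defaults set in `__init__` (480x848, 19 frames), the renamed factors give the following latent shape; a worked example of the integer arithmetic above (the channel count of 12 is an assumption taken from the `view(1, 12, 1, 1, 1)` denormalization later in the file):

```python
# Worked example of the shape prepare_latents builds for the default settings.
height, width, num_frames = 480, 848, 19
vae_spatial_scale_factor, vae_temporal_scale_factor = 8, 6

latent_height = height // vae_spatial_scale_factor                 # 480 // 8 = 60
latent_width = width // vae_spatial_scale_factor                   # 848 // 8 = 106
latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # 18 // 6 + 1 = 4

num_channels_latents = 12  # assumed; the pipeline reads transformer.config.in_channels
print((1, num_channels_latents, latent_frames, latent_height, latent_width))  # (1, 12, 4, 60, 106)
```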
@@ -485,7 +487,7 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_frames: int = 19,
-        num_inference_steps: int = 50,
+        num_inference_steps: int = 64,
         timesteps: List[int] = None,
         guidance_scale: float = 4.5,
         num_videos_per_prompt: Optional[int] = 1,
@@ -508,13 +510,13 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
+            height (`int`, *optional*, defaults to `self.default_height`):
                 The height in pixels of the generated image. This is set to 480 by default for the best results.
-            width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
+            width (`int`, *optional*, defaults to `self.default_width`):
                 The width in pixels of the generated image. This is set to 848 by default for the best results.
             num_frames (`int`, defaults to `19`):
                 The number of video frames to generate
-            num_inference_steps (`int`, *optional*, defaults to `50`):
+            num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
             timesteps (`List[int]`, *optional*):
@@ -574,8 +576,8 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or 480  # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
-        width = width or 848  # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
+        height = height or self.default_height
+        width = width or self.default_width

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -601,6 +603,7 @@ def __call__(
         batch_size = prompt_embeds.shape[0]

         device = self._execution_device
+
         # 3. Prepare text embeddings
         (
             prompt_embeds,
@@ -619,10 +622,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        # if self.do_classifier_free_guidance:
-        #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
-
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
@@ -636,16 +635,20 @@ def __call__(
             generator,
             latents,
         )
+        joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
+        negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+            joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0)

         # 5. Prepare timestep
         # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
         threshold_noise = 0.025
         sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
         sigmas = np.array(sigmas)

-        joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
-        negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
-
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
@@ -662,40 +665,28 @@ def __call__(
                 if self.interrupt:
                     continue

-                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                # timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
-
-                latent_model_input = latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

-                noise_pred_text = self.transformer(
+                noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     encoder_attention_mask=prompt_attention_mask,
                     joint_attention_mask=joint_attention_mask,
                     return_dict=False,
                 )[0]
+                # Mochi CFG + Sampling runs in FP32
+                noise_pred = noise_pred.to(torch.float32)

                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        timestep=timestep,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        joint_attention_mask=negative_joint_attention_mask,
-                        return_dict=False,
-                    )[0]
-                    noise_pred = noise_pred_uncond.float() + self.guidance_scale * (
-                        noise_pred_text.float() - noise_pred_uncond.float()
-                    )
-                else:
-                    noise_pred = noise_pred_text
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0]
+                latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
                 latents = latents.to(latents_dtype)

                 if latents.dtype != latents_dtype:
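The loop now runs one batched transformer call and splits the result with `chunk(2)` instead of two separate conditional/unconditional calls. A minimal sketch of that guidance combine in isolation (random tensors stand in for the model output):

```python
import torch

guidance_scale = 4.5
# Stand-in for the batched FP32 model output: [unconditional batch, text-conditioned batch].
noise_pred = torch.randn(2, 12, 4, 60, 106, dtype=torch.float32)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Classifier-free guidance: push the prediction away from the unconditional branch.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 12, 4, 60, 106])
```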
@@ -718,33 +709,27 @@ def __call__(

                 if XLA_AVAILABLE:
                     xm.mark_step()
+
         if output_type == "latent":
             video = latents
         else:
-            with torch.autocast("cuda", torch.float32):
-                # unscale/denormalize the latents
-                # denormalize with the mean and std if available and not None
-                has_latents_mean = (
-                    hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
                 )
-                has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
-                if has_latents_mean and has_latents_std:
-                    latents_mean = (
-                        torch.tensor(self.vae.config.latents_mean)
-                        .view(1, 12, 1, 1, 1)
-                        .to(latents.device, latents.dtype)
-                    )
-                    latents_std = (
-                        torch.tensor(self.vae.config.latents_std)
-                        .view(1, 12, 1, 1, 1)
-                        .to(latents.device, latents.dtype)
-                    )
-                    latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
-                else:
-                    latents = latents / self.vae.config.scaling_factor
-
-                video = self.vae.decode(latents, return_dict=False)[0]
-                video = self.video_processor.postprocess_video(video, output_type=output_type)
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess_video(video, output_type=output_type)

         # Offload all models
         self.maybe_free_model_hooks()