@@ -104,6 +104,7 @@ def retrieve_timesteps(
104104 Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
105105
106106 custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107+
107108 Args:
108109 scheduler (`SchedulerMixin`):
109110 The scheduler to get timesteps from.
@@ -272,8 +273,7 @@ def encode_prompt(
272273 prompt_embeds : Optional [torch .FloatTensor ] = None ,
273274 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
274275 max_sequence_length : int = 512 ,
275- do_classifier_free_guidance = True ,
276- lora_scale : Optional [float ] = None ,
276+ do_classifier_free_guidance : bool = True ,
277277 ):
278278 r"""
279279
@@ -305,14 +305,12 @@ def encode_prompt(
305305 device = device ,
306306 )
307307
308- dtype = self . text_encoder . dtype if self .text_encoder is not None else self . transformer . dtype
308+ prompt_embeds = prompt_embeds . to ( self .text_encoder . dtype )
309309
310- # TODO: Add negative prompts back
311310 if do_classifier_free_guidance and negative_prompt_embeds is None :
312311 negative_prompt = negative_prompt or ""
313312 # normalize str to list
314313 negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
315- )
316314
317315 if prompt is not None and type (prompt ) is not type (negative_prompt ):
318316 raise TypeError (
@@ -332,6 +330,7 @@ def encode_prompt(
332330 max_sequence_length = max_sequence_length ,
333331 device = device ,
334332 )
333+ negative_prompt_embeds = negative_prompt_embeds .to (self .text_encoder .dtype )
335334
336335 return prompt_embeds , negative_prompt_embeds
337336
@@ -532,9 +531,9 @@ def __call__(
532531 Examples:
533532
534533 Returns:
535- [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
536- is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
537- images.
534+ [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
535+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
536+ generated images.
538537 """
539538 height = height or self .default_height
540539 width = width or self .default_width
@@ -595,21 +594,12 @@ def __call__(
595594 threshold_noise = 0.025
596595 sigmas = linear_quadratic_schedule (num_inference_steps , threshold_noise )
597596
598- image_seq_len = latents .shape [1 ]
599- mu = calculate_shift (
600- image_seq_len ,
601- self .scheduler .config .base_image_seq_len ,
602- self .scheduler .config .max_image_seq_len ,
603- self .scheduler .config .base_shift ,
604- self .scheduler .config .max_shift ,
605- )
606597 timesteps , num_inference_steps = retrieve_timesteps (
607598 self .scheduler ,
608599 num_inference_steps ,
609600 device ,
610601 timesteps ,
611602 sigmas ,
612- mu = mu ,
613603 )
614604 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
615605 self ._num_timesteps = len (timesteps )
@@ -628,12 +618,16 @@ def __call__(
628618
629619 noise_pred = self .transformer (
630620 hidden_states = latent_model_input ,
631- timestep = timestep ,
621+ timestep = timestep / 1000 ,
632622 encoder_hidden_states = prompt_embeds ,
633623 joint_attention_kwargs = self .joint_attention_kwargs ,
634624 return_dict = False ,
635625 )[0 ]
636626
627+ if self .do_classifier_free_guidance :
628+ noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
629+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
630+
637631 # compute the previous noisy sample x_t -> x_t-1
638632 latents_dtype = latents .dtype
639633 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
@@ -660,18 +654,16 @@ def __call__(
660654 xm .mark_step ()
661655
662656 if output_type == "latent" :
663- image = latents
657+ video = latents
664658
665659 else :
666- latents = self ._unpack_latents (latents , height , width , self .vae_scale_factor )
667- latents = (latents / self .vae .config .scaling_factor ) + self .vae .config .shift_factor
668- image = self .vae .decode (latents , return_dict = False )[0 ]
669- image = self .image_processor .postprocess (image , output_type = output_type )
660+ video = self .vae .decode (latents , return_dict = False )[0 ]
661+ video = self .video_processor .postprocess (video , output_type = output_type )
670662
671663 # Offload all models
672664 self .maybe_free_model_hooks ()
673665
674666 if not return_dict :
675- return (image ,)
667+ return (video ,)
676668
677- return MochiPipelineOutput (images = image )
669+ return MochiPipelineOutput (frames = video )
0 commit comments