
Commit 8700d64

committed: update

1 parent: ebcbad2

File tree: 1 file changed (+17, −25)

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 17 additions & 25 deletions
@@ -104,6 +104,7 @@ def retrieve_timesteps(
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
@@ -272,8 +273,7 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        do_classifier_free_guidance=True,
-        lora_scale: Optional[float] = None,
+        do_classifier_free_guidance: bool = True,
     ):
         r"""
@@ -305,14 +305,12 @@ def encode_prompt(
                 device=device,
             )
 
-        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        prompt_embeds = prompt_embeds.to(self.text_encoder.dtype)
 
-        # TODO: Add negative prompts back
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             # normalize str to list
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            )
 
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
@@ -332,6 +330,7 @@ def encode_prompt(
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+            negative_prompt_embeds = negative_prompt_embeds.to(self.text_encoder.dtype)
 
         return prompt_embeds, negative_prompt_embeds
 
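Note: with `encode_prompt` now returning dtype-aligned positive and negative embeddings, the caller can stack them along the batch dimension so a single transformer forward pass serves both guidance branches. A minimal sketch of that pattern follows; the tensor shapes are illustrative stand-ins, not taken from this commit. The negative embeddings go first to match the `chunk(2)` unpacking order used in the denoising loop further down.

import torch

# Illustrative shapes only: (batch, sequence, hidden).
prompt_embeds = torch.randn(1, 256, 4096)
negative_prompt_embeds = torch.zeros(1, 256, 4096)

# Stack negative first, positive second, so chunk(2) later yields
# (uncond, text) in that order.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)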
@@ -532,9 +531,9 @@ def __call__(
         Examples:

         Returns:
-            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
-            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
-            images.
+            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
         """
         height = height or self.default_height
         width = width or self.default_width
@@ -595,21 +594,12 @@ def __call__(
         threshold_noise = 0.025
         sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)

-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
             device,
             timesteps,
             sigmas,
-            mu=mu,
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
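Note: this hunk drops the Flux-style resolution-dependent `mu` shift and lets the pipeline rely entirely on the linear–quadratic sigma schedule computed above. Below is a plausible sketch of what such a schedule does — sigmas ramp linearly up to `threshold_noise` for the first half of the steps, then quadratically up to full noise — written as an illustration of the idea, not the exact `linear_quadratic_schedule` implementation in this file.

def linear_quadratic_sigmas(num_steps: int, threshold_noise: float) -> list[float]:
    # Sketch only: a noise ramp that is linear up to threshold_noise for the
    # first half of the steps, then quadratic up to 1.0 for the rest.
    linear_steps = max(num_steps // 2, 1)
    ramp = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    quadratic_steps = num_steps - linear_steps
    ramp += [
        threshold_noise + (1.0 - threshold_noise) * ((i + 1) / quadratic_steps) ** 2
        for i in range(quadratic_steps)
    ]
    # Schedulers consume sigmas from high noise to low, so invert the ramp.
    return [1.0 - r for r in ramp]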
@@ -628,12 +618,16 @@ def __call__(

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
-                    timestep=timestep,
+                    timestep=timestep / 1000,
                    encoder_hidden_states=prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
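Note: two behavioural fixes land here. The timestep coming from the flow-matching scheduler (sigma scaled by 1000) is normalized back to the [0, 1] range before entering the transformer, and classifier-free guidance is recombined from the batched prediction. A self-contained illustration of the chunk-and-recombine step; the guidance scale and tensor shape are hypothetical values for demonstration.

import torch

guidance_scale = 4.5  # hypothetical value
# Stand-in for a batched prediction: row 0 is unconditional, row 1 conditional.
noise_pred = torch.randn(2, 4, 8, 8)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)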
@@ -660,18 +654,16 @@ def __call__(
                    xm.mark_step()

        if output_type == "latent":
-            image = latents
+            video = latents

        else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents, return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess(video, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
-            return (image,)
+            return (video,)

-        return MochiPipelineOutput(images=image)
+        return MochiPipelineOutput(frames=video)
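Note: the output path now decodes latents straight through the VAE and a video processor, returning frames instead of images. For context, a hedged end-to-end usage sketch of the pipeline this file defines; the checkpoint id and generation arguments are assumptions for illustration, not taken from this commit.

import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id; substitute whatever weights you actually have.
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # keeps peak VRAM manageable

frames = pipe(
    prompt="A corgi surfing a small wave at sunset",  # hypothetical prompt
    num_inference_steps=64,
    guidance_scale=4.5,
).frames

# Assumes the postprocessed output is a list of frame sequences.
export_to_video(frames[0], "mochi.mp4", fps=30)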
