
Commit ebcbad2

Commit message: update
1 parent 44987ad commit ebcbad2

1 file changed

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 36 additions & 15 deletions
@@ -261,40 +261,41 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
 
-        return prompt_embeds, prompt_attention_mask
+        return prompt_embeds
 
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
+        do_classifier_free_guidance=True,
         lora_scale: Optional[float] = None,
     ):
         r"""
 
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer` and `text_encoder`. If not defined, `prompt` is
-                used in all text-encoders
             device: (`torch.device`):
                 torch device
             num_videos_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
         prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds = self._get_t5_prompt_embeds(
@@ -307,8 +308,32 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
 
         # TODO: Add negative prompts back
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            # normalize str to list
+            negative_prompt = (
+                batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            )
 
-        return prompt_embeds
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
 
     def check_inputs(
         self,
@@ -541,7 +566,7 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        (prompt_embeds) = self.encode_prompt(
+        (prompt_embeds, negative_prompt_embeds) = self.encode_prompt(
             prompt=prompt,
             prompt_embeds=prompt_embeds,
             device=device,
@@ -589,12 +614,8 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
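
Side note (not part of this commit): concatenating `[negative_prompt_embeds, prompt_embeds]` along the batch dimension is the standard classifier-free guidance batching trick, letting the unconditional and conditional branches share a single transformer forward pass. Below is a minimal sketch of how such a batched prediction is typically split and recombined inside the denoising loop; the `self.transformer(...)` call signature and the `latents`, `t`, and `guidance_scale` names follow the common diffusers convention and are assumptions, not taken from this diff.

# Sketch only (assumed, not from this commit): the usual CFG step once
# prompt_embeds holds [negative, positive] stacked along the batch dim.
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

noise_pred = self.transformer(
    hidden_states=latent_model_input,
    encoder_hidden_states=prompt_embeds,
    timestep=t,
    return_dict=False,
)[0]

if self.do_classifier_free_guidance:
    # Split back into the unconditional and text-conditioned predictions, then
    # push the result away from the unconditional prediction by guidance_scale.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)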
