Skip to content

Commit ba98835

Browse files
committed
feat: enable true cfg in hunyuanvideo.
1 parent 4a4afd5 commit ba98835

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,12 +466,15 @@ def __call__(
466466
self,
467467
prompt: Union[str, List[str]] = None,
468468
prompt_2: Union[str, List[str]] = None,
469+
negative_prompt: Union[str, List[str]] = None,
470+
negative_prompt_2: Union[str, List[str]] = None,
469471
height: int = 720,
470472
width: int = 1280,
471473
num_frames: int = 129,
472474
num_inference_steps: int = 50,
473475
sigmas: List[float] = None,
474476
guidance_scale: float = 6.0,
477+
true_cfg_scale: float = 1.0,
475478
num_videos_per_prompt: Optional[int] = 1,
476479
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
477480
latents: Optional[torch.Tensor] = None,
@@ -590,6 +593,7 @@ def __call__(
590593
batch_size = prompt_embeds.shape[0]
591594

592595
# 3. Encode input prompt
596+
do_true_cfg = true_cfg_scale > 1.0 and negative_prompt is not None
593597
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
594598
prompt=prompt,
595599
prompt_2=prompt_2,
@@ -601,12 +605,29 @@ def __call__(
601605
device=device,
602606
max_sequence_length=max_sequence_length,
603607
)
608+
if do_true_cfg:
609+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
610+
prompt=negative_prompt,
611+
prompt_2=negative_prompt_2,
612+
prompt_template=prompt_template,
613+
num_videos_per_prompt=num_videos_per_prompt,
614+
prompt_embeds=None,
615+
pooled_prompt_embeds=None,
616+
prompt_attention_mask=None,
617+
device=device,
618+
max_sequence_length=max_sequence_length,
619+
)
604620

605621
transformer_dtype = self.transformer.dtype
606622
prompt_embeds = prompt_embeds.to(transformer_dtype)
607623
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
608624
if pooled_prompt_embeds is not None:
609625
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
626+
if do_true_cfg:
627+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
628+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
629+
if negative_pooled_prompt_embeds is not None:
630+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
610631

611632
# 4. Prepare timesteps
612633
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
@@ -658,6 +679,18 @@ def __call__(
658679
attention_kwargs=attention_kwargs,
659680
return_dict=False,
660681
)[0]
682+
if do_true_cfg:
683+
neg_noise_pred = self.transformer(
684+
hidden_states=latent_model_input,
685+
timestep=timestep,
686+
encoder_hidden_states=negative_prompt_embeds,
687+
encoder_attention_mask=negative_prompt_attention_mask,
688+
pooled_projections=negative_pooled_prompt_embeds,
689+
guidance=guidance,
690+
attention_kwargs=attention_kwargs,
691+
return_dict=False,
692+
)[0]
693+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
661694

662695
# compute the previous noisy sample x_t -> x_t-1
663696
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

0 commit comments

Comments
 (0)