Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,15 @@ def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
num_inference_steps: int = 50,
sigmas: List[float] = None,
guidance_scale: float = 6.0,
true_cfg_scale: float = 1.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -590,6 +593,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]

# 3. Encode input prompt
do_true_cfg = true_cfg_scale > 1.0 and negative_prompt is not None
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
Expand All @@ -601,12 +605,29 @@ def __call__(
device=device,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_template=prompt_template,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=None,
pooled_prompt_embeds=None,
prompt_attention_mask=None,
device=device,
max_sequence_length=max_sequence_length,
)

transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
if pooled_prompt_embeds is not None:
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
if do_true_cfg:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
if negative_pooled_prompt_embeds is not None:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)

# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
Expand Down Expand Up @@ -658,6 +679,18 @@ def __call__(
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
Expand Down