Skip to content

Commit 3613b23

Browse files
committed
update
1 parent 99e8da1 commit 3613b23

File tree

3 files changed

+133
-23
lines changed

3 files changed

+133
-23
lines changed

src/diffusers/pipelines/hunyuan_video/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25+
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
2526
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
2627

2728
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def encode_prompt(
345345
)
346346

347347
if pooled_prompt_embeds is None:
348-
if prompt_2 is None and pooled_prompt_embeds is None:
348+
if prompt_2 is None:
349349
prompt_2 = prompt
350350
pooled_prompt_embeds = self._get_clip_prompt_embeds(
351351
prompt,
@@ -424,13 +424,16 @@ def prepare_latents(
424424
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
425425
)
426426

427+
image = image.unsqueeze(2) # [B, C, 1, H, W]
427428
if isinstance(generator, list):
428429
image_latents = [
429430
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
430431
]
431432
else:
432433
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
433434

435+
image_latents = torch.cat(image_latents, dim=0).to(dtype)
436+
434437
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
435438
latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
436439
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
@@ -502,18 +505,24 @@ def __call__(
502505
image: PipelineImageInput,
503506
prompt: Union[str, List[str]] = None,
504507
prompt_2: Union[str, List[str]] = None,
508+
negative_prompt: Union[str, List[str]] = None,
509+
negative_prompt_2: Union[str, List[str]] = None,
505510
height: int = 544,
506511
width: int = 960,
507512
num_frames: int = 97,
508513
num_inference_steps: int = 50,
509514
sigmas: List[float] = None,
510-
guidance_scale: float = 6.0,
515+
true_cfg_scale: float = 6.0,
516+
guidance_scale: float = 1.0,
511517
num_videos_per_prompt: Optional[int] = 1,
512518
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
513519
latents: Optional[torch.Tensor] = None,
514520
prompt_embeds: Optional[torch.Tensor] = None,
515521
pooled_prompt_embeds: Optional[torch.Tensor] = None,
516522
prompt_attention_mask: Optional[torch.Tensor] = None,
523+
negative_prompt_embeds: Optional[torch.Tensor] = None,
524+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
525+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
517526
output_type: Optional[str] = "pil",
518527
return_dict: bool = True,
519528
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -534,6 +543,13 @@ def __call__(
534543
prompt_2 (`str` or `List[str]`, *optional*):
535544
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
536545
will be used instead.
546+
negative_prompt (`str` or `List[str]`, *optional*):
547+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
548+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
549+
not greater than `1`).
550+
negative_prompt_2 (`str` or `List[str]`, *optional*):
551+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
552+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
537553
height (`int`, defaults to `544`):
538554
The height in pixels of the generated image.
539555
width (`int`, defaults to `960`):
@@ -547,6 +563,8 @@ def __call__(
547563
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
548564
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
549565
will be used.
566+
true_cfg_scale (`float`, *optional*, defaults to 6.0):
567+
Enables true classifier-free guidance when greater than 1.0 and a `negative_prompt` (or pre-computed negative embeddings) is provided.
550568
guidance_scale (`float`, defaults to `1.0`):
551569
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
552570
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -567,6 +585,17 @@ def __call__(
567585
prompt_embeds (`torch.Tensor`, *optional*):
568586
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
569587
provided, text embeddings are generated from the `prompt` input argument.
588+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
589+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
590+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
591+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
592+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
593+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
594+
argument.
595+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
596+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
597+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
598+
input argument.
570599
output_type (`str`, *optional*, defaults to `"pil"`):
571600
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
572601
return_dict (`bool`, *optional*, defaults to `True`):
@@ -611,6 +640,11 @@ def __call__(
611640
prompt_template,
612641
)
613642

643+
has_neg_prompt = negative_prompt is not None or (
644+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
645+
)
646+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
647+
614648
self._guidance_scale = guidance_scale
615649
self._attention_kwargs = attention_kwargs
616650
self._current_timestep = None
@@ -627,6 +661,7 @@ def __call__(
627661
batch_size = prompt_embeds.shape[0]
628662

629663
# 3. Encode input prompt
664+
transformer_dtype = self.transformer.dtype
630665
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
631666
prompt=prompt,
632667
prompt_2=prompt_2,
@@ -638,21 +673,29 @@ def __call__(
638673
device=device,
639674
max_sequence_length=max_sequence_length,
640675
)
641-
642-
transformer_dtype = self.transformer.dtype
643676
prompt_embeds = prompt_embeds.to(transformer_dtype)
644677
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
645-
if pooled_prompt_embeds is not None:
646-
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
678+
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
679+
680+
if do_true_cfg:
681+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
682+
prompt=negative_prompt,
683+
prompt_2=negative_prompt_2,
684+
prompt_template=prompt_template,
685+
num_videos_per_prompt=num_videos_per_prompt,
686+
prompt_embeds=negative_prompt_embeds,
687+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
688+
prompt_attention_mask=negative_prompt_attention_mask,
689+
device=device,
690+
max_sequence_length=max_sequence_length,
691+
)
692+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
693+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
694+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
647695

648696
# 4. Prepare timesteps
649697
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
650-
timesteps, num_inference_steps = retrieve_timesteps(
651-
self.scheduler,
652-
num_inference_steps,
653-
device,
654-
sigmas=sigmas,
655-
)
698+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
656699

657700
# 5. Prepare latent variables
658701
vae_dtype = self.vae.dtype
@@ -702,6 +745,19 @@ def __call__(
702745
return_dict=False,
703746
)[0]
704747

748+
if do_true_cfg:
749+
neg_noise_pred = self.transformer(
750+
hidden_states=latent_model_input,
751+
timestep=timestep,
752+
encoder_hidden_states=negative_prompt_embeds,
753+
encoder_attention_mask=negative_prompt_attention_mask,
754+
pooled_projections=negative_pooled_prompt_embeds,
755+
guidance=guidance,
756+
attention_kwargs=attention_kwargs,
757+
return_dict=False,
758+
)[0]
759+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
760+
705761
# compute the previous noisy sample x_t -> x_t-1
706762
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
707763

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def encode_prompt(
325325
)
326326

327327
if pooled_prompt_embeds is None:
328-
if prompt_2 is None and pooled_prompt_embeds is None:
328+
if prompt_2 is None:
329329
prompt_2 = prompt
330330
pooled_prompt_embeds = self._get_clip_prompt_embeds(
331331
prompt,
@@ -470,18 +470,24 @@ def __call__(
470470
self,
471471
prompt: Union[str, List[str]] = None,
472472
prompt_2: Union[str, List[str]] = None,
473+
negative_prompt: Union[str, List[str]] = None,
474+
negative_prompt_2: Union[str, List[str]] = None,
473475
height: int = 720,
474476
width: int = 1280,
475477
num_frames: int = 129,
476478
num_inference_steps: int = 50,
477479
sigmas: List[float] = None,
480+
true_cfg_scale: float = 1.0,
478481
guidance_scale: float = 6.0,
479482
num_videos_per_prompt: Optional[int] = 1,
480483
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
481484
latents: Optional[torch.Tensor] = None,
482485
prompt_embeds: Optional[torch.Tensor] = None,
483486
pooled_prompt_embeds: Optional[torch.Tensor] = None,
484487
prompt_attention_mask: Optional[torch.Tensor] = None,
488+
negative_prompt_embeds: Optional[torch.Tensor] = None,
489+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
490+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
485491
output_type: Optional[str] = "pil",
486492
return_dict: bool = True,
487493
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -502,6 +508,13 @@ def __call__(
502508
prompt_2 (`str` or `List[str]`, *optional*):
503509
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
504510
will be used instead.
511+
negative_prompt (`str` or `List[str]`, *optional*):
512+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
513+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
514+
not greater than `1`).
515+
negative_prompt_2 (`str` or `List[str]`, *optional*):
516+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
517+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
505518
height (`int`, defaults to `720`):
506519
The height in pixels of the generated image.
507520
width (`int`, defaults to `1280`):
@@ -515,6 +528,8 @@ def __call__(
515528
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
516529
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
517530
will be used.
531+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
532+
Enables true classifier-free guidance when greater than 1.0 and a `negative_prompt` (or pre-computed negative embeddings) is provided.
518533
guidance_scale (`float`, defaults to `6.0`):
519534
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
520535
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -535,6 +550,17 @@ def __call__(
535550
prompt_embeds (`torch.Tensor`, *optional*):
536551
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
537552
provided, text embeddings are generated from the `prompt` input argument.
553+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
554+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
555+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
556+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
557+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
558+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
559+
argument.
560+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
561+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
562+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
563+
input argument.
538564
output_type (`str`, *optional*, defaults to `"pil"`):
539565
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
540566
return_dict (`bool`, *optional*, defaults to `True`):
@@ -579,6 +605,11 @@ def __call__(
579605
prompt_template,
580606
)
581607

608+
has_neg_prompt = negative_prompt is not None or (
609+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
610+
)
611+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
612+
582613
self._guidance_scale = guidance_scale
583614
self._attention_kwargs = attention_kwargs
584615
self._current_timestep = None
@@ -595,6 +626,7 @@ def __call__(
595626
batch_size = prompt_embeds.shape[0]
596627

597628
# 3. Encode input prompt
629+
transformer_dtype = self.transformer.dtype
598630
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
599631
prompt=prompt,
600632
prompt_2=prompt_2,
@@ -606,21 +638,29 @@ def __call__(
606638
device=device,
607639
max_sequence_length=max_sequence_length,
608640
)
609-
610-
transformer_dtype = self.transformer.dtype
611641
prompt_embeds = prompt_embeds.to(transformer_dtype)
612642
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
613-
if pooled_prompt_embeds is not None:
614-
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
643+
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
644+
645+
if do_true_cfg:
646+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
647+
prompt=negative_prompt,
648+
prompt_2=negative_prompt_2,
649+
prompt_template=prompt_template,
650+
num_videos_per_prompt=num_videos_per_prompt,
651+
prompt_embeds=negative_prompt_embeds,
652+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
653+
prompt_attention_mask=negative_prompt_attention_mask,
654+
device=device,
655+
max_sequence_length=max_sequence_length,
656+
)
657+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
658+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
659+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
615660

616661
# 4. Prepare timesteps
617662
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
618-
timesteps, num_inference_steps = retrieve_timesteps(
619-
self.scheduler,
620-
num_inference_steps,
621-
device,
622-
sigmas=sigmas,
623-
)
663+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
624664

625665
# 5. Prepare latent variables
626666
num_channels_latents = self.transformer.config.in_channels
@@ -665,6 +705,19 @@ def __call__(
665705
return_dict=False,
666706
)[0]
667707

708+
if do_true_cfg:
709+
neg_noise_pred = self.transformer(
710+
hidden_states=latent_model_input,
711+
timestep=timestep,
712+
encoder_hidden_states=negative_prompt_embeds,
713+
encoder_attention_mask=negative_prompt_attention_mask,
714+
pooled_projections=negative_pooled_prompt_embeds,
715+
guidance=guidance,
716+
attention_kwargs=attention_kwargs,
717+
return_dict=False,
718+
)[0]
719+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
720+
668721
# compute the previous noisy sample x_t -> x_t-1
669722
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
670723

0 commit comments

Comments
 (0)