@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import PIL.Image
 import torch
 from transformers import (
     CLIPImageProcessor,
@@ -26,7 +27,6 @@
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -220,6 +220,7 @@ def __init__(
             image_processor=image_processor,
         )
 
+        self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
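
The `0.476986` fallback matches the `scaling_factor` published in the HunyuanVideo VAE config, mirroring the `getattr` guard used by the two compression-ratio attributes below it. A minimal sketch of how such a cached factor is typically applied when encoding the conditioning image to latents; the actual consumer of `self.vae_scaling_factor` is outside this diff, so the `encode` call here is an assumption, not the commit's code:

```python
# Sketch only (assumed call site): the usual diffusers pattern multiplies the
# VAE's sampled latents by the scaling factor to move them into the latent
# space the transformer was trained on.
image_latents = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
image_latents = image_latents * self.vae_scaling_factor
```
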
@@ -240,8 +241,6 @@ def _get_llama_prompt_embeds(
         dtype = dtype or self.text_encoder.dtype
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
         prompt = [prompt_template["template"].format(p) for p in prompt]
 
         crop_start = prompt_template.get("crop_start", None)
@@ -351,13 +350,6 @@ def _get_llama_prompt_embeds(
             prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
             prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
 
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
-        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
-
         return prompt_embeds, prompt_attention_mask
 
     def _get_clip_prompt_embeds(
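
This hunk, together with the matching removals in `_get_clip_prompt_embeds` below, makes the embedding helpers return exactly one row per input prompt, which is also why the now-unused `batch_size` locals are dropped. The `num_videos_per_prompt` expansion therefore has to happen in the caller; a self-contained sketch of the equivalent caller-side pattern (not shown in this diff):

```python
import torch

# Dummy shapes: 2 prompts, 5 tokens, 16-dim embeddings, 3 videos per prompt.
prompt_embeds = torch.randn(2, 5, 16)
prompt_attention_mask = torch.ones(2, 5, dtype=torch.bool)
num_videos_per_prompt = 3

# Caller-side equivalent of the removed duplication: one row per (prompt, video),
# grouped per prompt exactly like the old repeat + view pair.
prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
assert prompt_embeds.shape == (6, 5, 16)
```
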
@@ -372,7 +364,6 @@ def _get_clip_prompt_embeds(
         dtype = dtype or self.text_encoder_2.dtype
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
 
         text_inputs = self.tokenizer_2(
             prompt,
@@ -392,11 +383,6 @@ def _get_clip_prompt_embeds(
             )
 
         prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
-
         return prompt_embeds
 
     def encode_prompt(
@@ -581,7 +567,7 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image: PipelineImageInput,
+        image: PIL.Image.Image,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
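
The public signature narrows from `PipelineImageInput` (which also admits tensors, NumPy arrays, and lists) to a single `PIL.Image.Image`, matching how the rest of this commit handles the image. A hedged usage sketch against the new signature; the checkpoint id is an assumption for illustration, not part of this diff:

```python
import PIL.Image
import torch
from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video

# Checkpoint id assumed for illustration.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.float16
).to("cuda")

image = PIL.Image.open("input.png").convert("RGB")  # must be a single PIL image now
frames = pipe(image=image, prompt="A cat walks on the grass").frames[0]
export_to_video(frames, "output.mp4", fps=15)
```
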
@@ -741,10 +727,10 @@ def __call__(
 
         # 3. Prepare latent variables
         vae_dtype = self.vae.dtype
-        image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
+        image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
         num_channels_latents = (self.transformer.config.in_channels - 1) // 2
         latents, image_latents = self.prepare_latents(
-            image,
+            image_tensor,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
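
The `image` to `image_tensor` rename keeps the caller's PIL image from being overwritten by its preprocessed tensor. That matters now because `encode_prompt` expects a PIL image, so the true-CFG branch below can no longer derive its unconditional input from the tensor.
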
@@ -778,8 +764,9 @@ def __call__(
         pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
 
         if do_true_cfg:
+            black_image = PIL.Image.new("RGB", (width, height), 0)
             negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
-                image=torch.full_like(image, fill_value=-1),
+                image=black_image,
                 prompt=negative_prompt,
                 prompt_2=negative_prompt_2,
                 prompt_template=prompt_template,
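
Why a black PIL image is a drop-in replacement for the old `torch.full_like(image, fill_value=-1)`: the video processor's default normalization maps pixels to `[-1, 1]`, so an all-zero RGB image preprocesses to an all `-1` tensor. A quick self-contained check, assuming the default `do_normalize=True` behavior:

```python
import PIL.Image
import torch
from diffusers.video_processor import VideoProcessor

# An all-black RGB image normalizes to -1 everywhere, matching the tensor the
# old code built directly with torch.full_like.
processor = VideoProcessor(vae_scale_factor=8)
black_image = PIL.Image.new("RGB", (64, 64), 0)
neg = processor.preprocess(black_image, height=64, width=64)
assert torch.allclose(neg, torch.full_like(neg, -1.0))
```
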
@@ -808,7 +795,7 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                latent_model_input = torch.cat([latents, image_latents, mask], dim=1)
+                latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
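
The added `.to(transformer_dtype)` on the concatenated input guards against a dtype mismatch: `latents`, `image_latents`, and `mask` follow the VAE's dtype, which can differ from the transformer's when the two components are loaded in different precisions.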