 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 
 DEFAULT_PROMPT_TEMPLATE = {
     "template": (
-        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+        "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
         "1. The main content and theme of the video."
         "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
         "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
         "4. background environment, light, style and atmosphere."
-        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+        "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
         "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
     ),
-    "crop_start": 95,
+    "crop_start": 103,
89+ "image_emb_start" : 5 ,
90+ "image_emb_end" : 581 ,
91+ "image_emb_len" : 576 ,
92+ "double_return_token_id" : 271 ,
8793}
8894
8995
@@ -147,6 +153,20 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
     r"""
     Pipeline for image-to-video generation using HunyuanVideo.
@@ -197,6 +217,7 @@ def __init__(
             scheduler=scheduler,
             text_encoder_2=text_encoder_2,
             tokenizer_2=tokenizer_2,
+            image_processor=image_processor,
         )
 
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
@@ -205,13 +226,15 @@ def __init__(
 
     def _get_llama_prompt_embeds(
         self,
+        image: torch.Tensor,
         prompt: Union[str, List[str]],
         prompt_template: Dict[str, Any],
         num_videos_per_prompt: int = 1,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 256,
         num_hidden_layers_to_skip: int = 2,
+        image_embed_interleave: int = 2,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -232,8 +255,8 @@ def _get_llama_prompt_embeds(
             return_attention_mask=False,
         )
         crop_start = prompt_template_input["input_ids"].shape[-1]
-        # Remove <|eot_id|> token and placeholder {}
-        crop_start -= 2
+        # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {}
+        crop_start -= 5
 
         max_sequence_length += crop_start
         text_inputs = self.tokenizer(
@@ -249,16 +272,84 @@ def _get_llama_prompt_embeds(
         text_input_ids = text_inputs.input_ids.to(device=device)
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
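+        # Note: despite the name, image_embeds holds preprocessed pixel values for the multimodal text encoder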
+        image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+
         prompt_embeds = self.text_encoder(
             input_ids=text_input_ids,
             attention_mask=prompt_attention_mask,
+            pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is not None and crop_start > 0:
-            prompt_embeds = prompt_embeds[:, crop_start:]
-            prompt_attention_mask = prompt_attention_mask[:, crop_start:]
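+            # The single <image> token expands to image_emb_len hidden states, so text indices shift by image_emb_len - 1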
+            text_crop_start = crop_start - 1 + image_emb_len
+            batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
+
+            if last_double_return_token_indices.shape[0] == 3:
+                # in case the prompt is too long
+                last_double_return_token_indices = torch.cat(
+                    (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
+                )
+                batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+
+            last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
+                :, -1
+            ]
+            batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
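+            # Crop out the 4 assistant-header tokens that precede the last "\n\n"; hidden-state indices include the
+            # image-token offset, while the attention-mask indices (token space) do not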
+            assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
+            assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
+            attention_mask_assistant_crop_start = last_double_return_token_indices - 4
+            attention_mask_assistant_crop_end = last_double_return_token_indices
+
+            prompt_embed_list = []
+            prompt_attention_mask_list = []
+            image_embed_list = []
+            image_attention_mask_list = []
+
+            for i in range(text_input_ids.shape[0]):
+                prompt_embed_list.append(
+                    torch.cat(
+                        [
+                            prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()],
+                            prompt_embeds[i, assistant_crop_end[i].item() :],
+                        ]
+                    )
+                )
+                prompt_attention_mask_list.append(
+                    torch.cat(
+                        [
+                            prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()],
+                            prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :],
+                        ]
+                    )
+                )
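+                # Collect the image-token hidden states separately so they can be prepended to the text tokens below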
+                image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end])
+                image_attention_mask_list.append(
+                    torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype)
+                )
+
+            prompt_embed_list = torch.stack(prompt_embed_list)
+            prompt_attention_mask_list = torch.stack(prompt_attention_mask_list)
+            image_embed_list = torch.stack(image_embed_list)
+            image_attention_mask_list = torch.stack(image_attention_mask_list)
+
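+            # Optionally keep only every image_embed_interleave-th image token to shorten the conditioning sequence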
+            if image_embed_interleave < 6:
+                image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
+                image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]
+
+            assert (
+                prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0]
+                and image_embed_list.shape[0] == image_attention_mask_list.shape[0]
+            )
+
+            prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
+            prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         _, seq_len, _ = prompt_embeds.shape
@@ -310,6 +401,7 @@ def _get_clip_prompt_embeds(
 
     def encode_prompt(
         self,
+        image: torch.Tensor,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]] = None,
         prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
@@ -323,6 +415,7 @@ def encode_prompt(
     ):
         if prompt_embeds is None:
             prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+                image,
                 prompt,
                 prompt_template,
                 num_videos_per_prompt,
@@ -393,6 +486,7 @@ def check_inputs(
 
     def prepare_latents(
         self,
+        image: torch.Tensor,
         batch_size: int,
         num_channels_latents: int = 32,
         height: int = 720,
@@ -403,24 +497,36 @@ def prepare_latents(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if latents is not None:
-            return latents.to(device=device, dtype=dtype)
-
-        shape = (
-            batch_size,
-            num_channels_latents,
-            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        return latents
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
+        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+        image = image.unsqueeze(2)  # [B, C, 1, H, W]
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+            ]
+        else:
+            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+        image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
+        image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
+
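+        # Mix a small amount of the image latents into the initial noise (99.9% noise, 0.1% first-frame latents)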
+        t = torch.tensor([0.999]).to(device=device)
+        latents = latents * t + image_latents * (1 - t)
+
+        return latents, image_latents
 
     def enable_vae_slicing(self):
         r"""
@@ -475,6 +581,7 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
+        image: PipelineImageInput,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
@@ -632,9 +739,30 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # 3. Encode input prompt
+        # 3. Prepare latent variables
+        vae_dtype = self.vae.dtype
+        image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
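+        # Transformer input channels = 2 * latent channels (noisy latents + image latents) + 1 mask channel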
+        num_channels_latents = (self.transformer.config.in_channels - 1) // 2
+        latents, image_latents = self.prepare_latents(
+            image,
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            num_frames,
+            torch.float32,
+            device,
+            generator,
+            latents,
+        )
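+        # Keep the conditioning image only in the first latent frame; the single-channel mask flags which frames are conditioned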
+        image_latents[:, :, 1:] = 0
+        mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
+        mask[:, :, 1:] = 0
+
+        # 4. Encode input prompt
         transformer_dtype = self.transformer.dtype
         prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+            image=image,
             prompt=prompt,
             prompt_2=prompt_2,
             prompt_template=prompt_template,
@@ -651,6 +779,7 @@ def __call__(
 
         if do_true_cfg:
             negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
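+                # Condition the negative/unconditional branch on a blank image (all -1, i.e. black in the [-1, 1] pixel range)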
+                image=torch.full_like(image, fill_value=-1),
                 prompt=negative_prompt,
                 prompt_2=negative_prompt_2,
                 prompt_template=prompt_template,
@@ -669,23 +798,6 @@ def __call__(
         sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
-        # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels
-        latents = self.prepare_latents(
-            batch_size * num_videos_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            num_frames,
-            torch.float32,
-            device,
-            generator,
-            latents,
-        )
-
-        # 6. Prepare guidance condition
-        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
-
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -696,7 +808,7 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
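+                # Channel-wise concat of noisy latents, first-frame image latents, and the conditioning mask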
+                latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
@@ -706,7 +818,6 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -718,7 +829,6 @@ def __call__(
                     encoder_hidden_states=negative_prompt_embeds,
                     encoder_attention_mask=negative_prompt_attention_mask,
                     pooled_projections=negative_pooled_prompt_embeds,
-                    guidance=guidance,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]