 
 import numpy as np
 import torch
+from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
@@ -140,7 +141,8 @@ def __init__(
         text_encoder: TextEncoder,
         transformer: HunyuanVideoTransformer3DModel,
         scheduler: KarrasDiffusionSchedulers,
-        text_encoder_2: Optional[TextEncoder] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
     ):
         super().__init__()
 
@@ -150,6 +152,7 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -246,6 +249,49 @@ def encode_prompt(
             attention_mask,
         )
 
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 77,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder_2.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
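+        # pad and truncate to max_sequence_length so batched prompts share a fixed shape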
+        text_inputs = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
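+        # re-tokenize without truncation to detect, and warn about, prompt text that was clipped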
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
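+        # only the pooled CLIP output is used: one global embedding vector per prompt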
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+        return prompt_embeds
+
     def check_inputs(
         self,
         prompt,
@@ -577,7 +619,6 @@ def __call__(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -674,7 +715,6 @@ def __call__(
         )
 
         self._guidance_scale = guidance_scale
-        self._clip_skip = clip_skip
         self._interrupt = False
 
         # 2. Define call parameters
@@ -698,27 +738,18 @@ def __call__(
             num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
             attention_mask=prompt_attention_mask,
-            clip_skip=self.clip_skip,
             data_type=data_type,
         )
 
         if self.text_encoder_2 is not None:
-            (
-                prompt_embeds_2,
-                prompt_mask_2,
-            ) = self.encode_prompt(
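+            # the CLIP branch yields pooled embeddings only, so no second attention mask is produced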
+            prompt_embeds_2 = self._get_clip_prompt_embeds(
                 prompt,
-                device,
                 num_videos_per_prompt,
-                prompt_embeds=prompt_embeds_2,
-                attention_mask=None,
-                clip_skip=self.clip_skip,
-                text_encoder=self.text_encoder_2,
-                data_type=data_type,
+                device=device,
             )
         else:
             prompt_embeds_2 = None
-            prompt_mask_2 = None
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(