@@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
902902    tokenizer ,
903903    prompt : str ,
904904    device = None ,
905+     text_input_ids = None ,
905906    num_images_per_prompt : int  =  1 ,
906907):
907908    prompt  =  [prompt ] if  isinstance (prompt , str ) else  prompt 
908909    batch_size  =  len (prompt )
909910
910-     text_inputs  =  tokenizer (
911-         prompt ,
912-         padding = "max_length" ,
913-         max_length = 77 ,
914-         truncation = True ,
915-         return_tensors = "pt" ,
916-     )
911+     if  tokenizer  is  not   None :
912+         text_inputs  =  tokenizer (
913+             prompt ,
914+             padding = "max_length" ,
915+             max_length = 77 ,
916+             truncation = True ,
917+             return_tensors = "pt" ,
918+         )
919+ 
920+         text_input_ids  =  text_inputs .input_ids 
921+     else :
922+         if  text_input_ids  is  None :
923+             raise  ValueError ("text_input_ids must be provided when the tokenizer is not specified" )
917924
918-     text_input_ids  =  text_inputs .input_ids 
919925    prompt_embeds  =  text_encoder (text_input_ids .to (device ), output_hidden_states = True )
920926
921927    pooled_prompt_embeds  =  prompt_embeds [0 ]
@@ -937,6 +943,7 @@ def encode_prompt(
937943    max_sequence_length ,
938944    device = None ,
939945    num_images_per_prompt : int  =  1 ,
946+     text_input_ids_list = None ,
940947):
941948    prompt  =  [prompt ] if  isinstance (prompt , str ) else  prompt 
942949
@@ -945,13 +952,14 @@ def encode_prompt(
945952
946953    clip_prompt_embeds_list  =  []
947954    clip_pooled_prompt_embeds_list  =  []
948-     for  tokenizer , text_encoder  in  zip (clip_tokenizers , clip_text_encoders ):
955+     for  i , ( tokenizer , text_encoder )  in  enumerate ( zip (clip_tokenizers , clip_text_encoders ) ):
949956        prompt_embeds , pooled_prompt_embeds  =  _encode_prompt_with_clip (
950957            text_encoder = text_encoder ,
951958            tokenizer = tokenizer ,
952959            prompt = prompt ,
953960            device = device  if  device  is  not   None  else  text_encoder .device ,
954961            num_images_per_prompt = num_images_per_prompt ,
962+             text_input_ids = text_input_ids_list [i ] if  text_input_ids_list  else  None ,
955963        )
956964        clip_prompt_embeds_list .append (prompt_embeds )
957965        clip_pooled_prompt_embeds_list .append (pooled_prompt_embeds )
0 commit comments