@@ -573,6 +573,13 @@ def parse_args(input_args=None):
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--clip_skip",
+        type=int,
+        default=None,
+        help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+        "the output of the pre-final layer will be used for computing the prompt embeddings.",
+    )
 
     parser.add_argument(
         "--text_encoder_lr",
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
 
 
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
     prompt_embeds_list = []
 
     for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds[-1][-2]
+        if clip_skip is None:
+            prompt_embeds = prompt_embeds[-1][-2]
+        else:
+            # "2" because SDXL always indexes from the penultimate layer.
+            prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
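A minimal sketch (not part of the patch) of how the indexing above maps onto the hidden-states tuple returned by a CLIP text encoder called with `output_hidden_states=True`. The 13-entry list below is an assumption standing in for `prompt_embeds[-1]`: `clip_skip=None` keeps the penultimate layer, and each increment of `clip_skip` moves one layer further back.

```python
# Stand-in for prompt_embeds[-1]: the hidden-states tuple of a CLIP text encoder
# (assumed here to be the input embeddings plus 12 transformer layer outputs).
hidden_states = [f"layer_{i}" for i in range(13)]

def select_hidden_state(hidden_states, clip_skip=None):
    """Mirror of the layer-selection logic in encode_prompt above."""
    if clip_skip is None:
        return hidden_states[-2]            # penultimate layer, the SDXL default
    # "2" because SDXL always indexes from the penultimate layer.
    return hidden_states[-(clip_skip + 2)]

assert select_hidden_state(hidden_states) == "layer_11"               # penultimate layer
assert select_hidden_state(hidden_states, clip_skip=1) == "layer_10"  # one layer earlier
```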
@@ -1830,9 +1841,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
         tokenizers = [tokenizer_one, tokenizer_two]
         text_encoders = [text_encoder_one, text_encoder_two]
 
-        def compute_text_embeddings(prompt, text_encoders, tokenizers):
+        def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
             with torch.no_grad():
-                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)
                 prompt_embeds = prompt_embeds.to(accelerator.device)
                 pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
@@ -1842,7 +1853,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # the redundant encoding.
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
-            args.instance_prompt, text_encoders, tokenizers
+            args.instance_prompt, text_encoders, tokenizers, args.clip_skip
         )
 
     # Handle class prompt for prior-preservation.
@@ -2052,7 +2063,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
                         prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
+                            prompts, text_encoders, tokenizers, args.clip_skip
                         )
 
                     else:
@@ -2147,6 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
+                        clip_skip=args.clip_skip,
                     )
                     unet_added_conditions.update(
                         {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}