@@ -573,6 +573,13 @@ def parse_args(input_args=None):
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--clip_skip",
+        type=int,
+        default=None,
+        help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+        "the output of the pre-final layer will be used for computing the prompt embeddings.",
+    )
 
     parser.add_argument(
         "--text_encoder_lr",
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
 
 
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
     prompt_embeds_list = []
 
     for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds[-1][-2]
+        if clip_skip is None:
+            prompt_embeds = prompt_embeds[-1][-2]
+        else:
+            # "2" because SDXL always indexes from the penultimate layer.
+            prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
@@ -1830,9 +1841,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
     tokenizers = [tokenizer_one, tokenizer_two]
     text_encoders = [text_encoder_one, text_encoder_two]
 
-    def compute_text_embeddings(prompt, text_encoders, tokenizers):
+    def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip=clip_skip)
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
@@ -1842,7 +1853,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # the redundant encoding.
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
-            args.instance_prompt, text_encoders, tokenizers
+            args.instance_prompt, text_encoders, tokenizers, args.clip_skip
         )
 
     # Handle class prompt for prior-preservation.
@@ -2052,7 +2063,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
                         prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
+                            prompts, text_encoders, tokenizers, args.clip_skip
                         )
 
                     else:
@@ -2147,6 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
+                        clip_skip=args.clip_skip,
                     )
                     unet_added_conditions.update(
                         {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}