Skip to content

Commit beb1c01

Browse files
[advanced dreambooth lora] add clip_skip arg (#8715)
* add clip_skip * style * smol fix --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 06ee4db commit beb1c01

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,13 @@ def parse_args(input_args=None):
573573
default=1e-4,
574574
help="Initial learning rate (after the potential warmup period) to use.",
575575
)
576+
parser.add_argument(
577+
"--clip_skip",
578+
type=int,
579+
default=None,
580+
help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
581+
"the output of the pre-final layer will be used for computing the prompt embeddings.",
582+
)
576583

577584
parser.add_argument(
578585
"--text_encoder_lr",
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
12361243

12371244

12381245
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
1239-
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
1246+
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
12401247
prompt_embeds_list = []
12411248

12421249
for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
12531260

12541261
# We are only ALWAYS interested in the pooled output of the final text encoder
12551262
pooled_prompt_embeds = prompt_embeds[0]
1256-
prompt_embeds = prompt_embeds[-1][-2]
1263+
if clip_skip is None:
1264+
prompt_embeds = prompt_embeds[-1][-2]
1265+
else:
1266+
# "2" because SDXL always indexes from the penultimate layer.
1267+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
12571268
bs_embed, seq_len, _ = prompt_embeds.shape
12581269
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
12591270
prompt_embeds_list.append(prompt_embeds)
@@ -1830,9 +1841,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
18301841
tokenizers = [tokenizer_one, tokenizer_two]
18311842
text_encoders = [text_encoder_one, text_encoder_two]
18321843

def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
    """Encode `prompt` with both SDXL text encoders and move the results to the accelerator device.

    Args:
        prompt: A string (or batch of strings) to embed.
        text_encoders: The two frozen SDXL CLIP text encoders.
        tokenizers: The tokenizers matching `text_encoders`, in the same order.
        clip_skip: Number of final CLIP layers to skip when selecting the
            hidden state used for the prompt embeddings (`None` keeps the
            default penultimate-layer behavior).

    Returns:
        Tuple of `(prompt_embeds, pooled_prompt_embeds)` on `accelerator.device`.
    """
    with torch.no_grad():
        # clip_skip MUST be passed by keyword: encode_prompt's signature is
        # (text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None),
        # so a fourth positional argument would silently bind to
        # text_input_ids_list (ignored when tokenizers are given) and the
        # --clip_skip flag would never take effect.
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompt, clip_skip=clip_skip
        )
        prompt_embeds = prompt_embeds.to(accelerator.device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
        return prompt_embeds, pooled_prompt_embeds
@@ -1842,7 +1853,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
18421853
# the redundant encoding.
18431854
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
18441855
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
1845-
args.instance_prompt, text_encoders, tokenizers
1856+
args.instance_prompt, text_encoders, tokenizers, args.clip_skip
18461857
)
18471858

18481859
# Handle class prompt for prior-preservation.
@@ -2052,7 +2063,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20522063
if train_dataset.custom_instance_prompts:
20532064
if freeze_text_encoder:
20542065
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
2055-
prompts, text_encoders, tokenizers
2066+
prompts, text_encoders, tokenizers, args.clip_skip
20562067
)
20572068

20582069
else:
@@ -2147,6 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21472158
tokenizers=None,
21482159
prompt=None,
21492160
text_input_ids_list=[tokens_one, tokens_two],
2161+
clip_skip=args.clip_skip,
21502162
)
21512163
unet_added_conditions.update(
21522164
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}

0 commit comments

Comments
 (0)