Skip to content

Commit 4497b3e

Browse files
authored
[Training] make DreamBooth SDXL LoRA training script compatible with torch.compile (#6483)
* make it torch.compile comaptible * make the text encoder compatible too. * style
1 parent fc63ebd commit 4497b3e

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
780780
text_input_ids = text_input_ids_list[i]
781781

782782
prompt_embeds = text_encoder(
783-
text_input_ids.to(text_encoder.device),
784-
output_hidden_states=True,
783+
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
785784
)
786785

787786
# We are only ALWAYS interested in the pooled output of the final text encoder
788787
pooled_prompt_embeds = prompt_embeds[0]
789-
prompt_embeds = prompt_embeds.hidden_states[-2]
788+
prompt_embeds = prompt_embeds[-1][-2]
790789
bs_embed, seq_len, _ = prompt_embeds.shape
791790
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
792791
prompt_embeds_list.append(prompt_embeds)
@@ -1429,7 +1428,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14291428
timesteps,
14301429
prompt_embeds_input,
14311430
added_cond_kwargs=unet_added_conditions,
1432-
).sample
1431+
return_dict=False,
1432+
)[0]
14331433
else:
14341434
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
14351435
prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1443,8 +1443,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14431443
)
14441444
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
14451445
model_pred = unet(
1446-
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
1447-
).sample
1446+
noisy_model_input,
1447+
timesteps,
1448+
prompt_embeds_input,
1449+
added_cond_kwargs=unet_added_conditions,
1450+
return_dict=False,
1451+
)[0]
14481452

14491453
# Get the target for loss depending on the prediction type
14501454
if noise_scheduler.config.prediction_type == "epsilon":

0 commit comments

Comments
 (0)