Skip to content

Commit ba4dece

Browse files
committed
revert unwrap_model change temp
1 parent 0565932 commit ba4dece

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ def _encode_prompt_with_t5(
941941

942942
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
943943

944-
dtype = unwrap_model(text_encoder).dtype
944+
dtype = text_encoder.dtype
945945
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
946946

947947
_, seq_len, _ = prompt_embeds.shape
@@ -984,7 +984,7 @@ def _encode_prompt_with_clip(
984984

985985
# Use pooled output of CLIPTextModel
986986
prompt_embeds = prompt_embeds.pooler_output
987-
prompt_embeds = prompt_embeds.to(dtype=unwrap_model(text_encoder).dtype, device=device)
987+
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
988988

989989
# duplicate text embeddings for each generation per prompt, using mps friendly method
990990
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1003,7 +1003,7 @@ def encode_prompt(
10031003
text_input_ids_list=None,
10041004
):
10051005
prompt = [prompt] if isinstance(prompt, str) else prompt
1006-
dtype = unwrap_model(text_encoders[0]).dtype
1006+
dtype = text_encoders[0].dtype
10071007

10081008
pooled_prompt_embeds = _encode_prompt_with_clip(
10091009
text_encoder=text_encoders[0],

0 commit comments

Comments
 (0)