Skip to content

Commit 0729c66

Browse files
committed
add module.dtype fix to dreambooth script
1 parent 7492e92 commit 0729c66

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
895895

896896
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
897897

898-
dtype = text_encoder.dtype
898+
if hasattr(text_encoder, "module"):
899+
dtype = text_encoder.module.dtype
900+
else:
901+
dtype = text_encoder.dtype
899902
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
900903

901904
_, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
936939

937940
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
938941

942+
if hasattr(text_encoder, "module"):
943+
dtype = text_encoder.module.dtype
944+
else:
945+
dtype = text_encoder.dtype
939946
# Use pooled output of CLIPTextModel
940947
prompt_embeds = prompt_embeds.pooler_output
941-
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
948+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
942949

943950
# duplicate text embeddings for each generation per prompt, using mps friendly method
944951
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
958965
):
959966
prompt = [prompt] if isinstance(prompt, str) else prompt
960967
batch_size = len(prompt)
961-
dtype = text_encoders[0].dtype
968+
969+
if hasattr(text_encoders[0], "module"):
970+
dtype = text_encoders[0].module.dtype
971+
else:
972+
dtype = text_encoders[0].dtype
973+
962974
device = device if device is not None else text_encoders[1].device
963975
pooled_prompt_embeds = _encode_prompt_with_clip(
964976
text_encoder=text_encoders[0],

0 commit comments

Comments
 (0)