
Commit c155f22

add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap_model
Parent: ba4dece
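Context for the first change: when Accelerate prepares a model for multi-GPU training it wraps it in torch.nn.parallel.DistributedDataParallel, which stores the original model under .module and does not proxy attributes such as the dtype property of a transformers PreTrainedModel, so text_encoder.dtype raises AttributeError on the wrapper. A minimal sketch of the failure mode and of the hasattr pattern this commit applies (Wrapper is a stand-in for DDP, which needs an initialized process group to construct):

import torch.nn as nn

class TinyTextEncoder(nn.Module):
    """Stand-in for a transformers text encoder exposing a dtype property."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

class Wrapper(nn.Module):
    """Mimics how DistributedDataParallel stores the wrapped model as .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

encoder = Wrapper(TinyTextEncoder())

# Direct access fails on the wrapper: nn.Module defines no `dtype`
# and the wrapper does not forward it.
#   encoder.dtype  -> AttributeError

# The pattern from this commit: reach through .module when present.
if hasattr(encoder, "module"):
    dtype = encoder.module.dtype
else:
    dtype = encoder.dtype
print(dtype)  # torch.float32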

1 file changed: 19 additions, 8 deletions

examples/dreambooth/train_dreambooth_lora_flux.py

@@ -941,7 +941,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -982,9 +985,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1003,7 +1010,11 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
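The same three-line fallback now appears in all three encode helpers above; it could be folded into one small function along these lines (a hypothetical refactor for illustration only, not part of this commit; _get_dtype is an invented name):

def _get_dtype(model):
    # DDP keeps the wrapped model under .module; plain models expose .dtype directly.
    return model.module.dtype if hasattr(model, "module") else model.dtype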
@@ -1628,7 +1639,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1719,7 +1730,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1837,9 +1848,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             pipeline = FluxPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 vae=vae,
-                text_encoder=accelerator.unwrap_model(text_encoder_one),
-                text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                transformer=accelerator.unwrap_model(transformer),
+                text_encoder=unwrap_model(text_encoder_one),
+                text_encoder_2=unwrap_model(text_encoder_two),
+                transformer=unwrap_model(transformer),
                 revision=args.revision,
                 variant=args.variant,
                 torch_dtype=weight_dtype,
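The unwrap_model calls above refer to a script-local helper that diffusers training scripts typically define inside main(); roughly the following, with accelerator taken from the enclosing scope (a sketch of the usual pattern, not quoted from this commit):

from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    # First strip the Accelerate wrapper (e.g. DDP), then the
    # torch.compile wrapper if the model was compiled.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model

Unlike accelerator.unwrap_model alone, this also reaches through torch.compile's OptimizedModule, so attribute access such as .config.guidance_embeds works in both cases.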

Comments (0)