Skip to content

Commit 7492e92

Browse files
committed
Make changes for distributed training and unify unwrap_model calls in the advanced script
1 parent 9c4368d commit 7492e92

File tree

1 file changed

+18
-8
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,10 @@ def _encode_prompt_with_t5(
11961196

11971197
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
11981198

1199-
dtype = text_encoder.dtype
1199+
if hasattr(text_encoder, "module"):
1200+
dtype = text_encoder.module.dtype
1201+
else:
1202+
dtype = text_encoder.dtype
12001203
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
12011204

12021205
_, seq_len, _ = prompt_embeds.shape
@@ -1237,9 +1240,13 @@ def _encode_prompt_with_clip(
12371240

12381241
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
12391242

1243+
if hasattr(text_encoder, "module"):
1244+
dtype = text_encoder.module.dtype
1245+
else:
1246+
dtype = text_encoder.dtype
12401247
# Use pooled output of CLIPTextModel
12411248
prompt_embeds = prompt_embeds.pooler_output
1242-
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
1249+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
12431250

12441251
# duplicate text embeddings for each generation per prompt, using mps friendly method
12451252
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1258,7 +1265,10 @@ def encode_prompt(
12581265
text_input_ids_list=None,
12591266
):
12601267
prompt = [prompt] if isinstance(prompt, str) else prompt
1261-
dtype = text_encoders[0].dtype
1268+
if hasattr(text_encoders[0], "module"):
1269+
dtype = text_encoders[0].module.dtype
1270+
else:
1271+
dtype = text_encoders[0].dtype
12621272

12631273
pooled_prompt_embeds = _encode_prompt_with_clip(
12641274
text_encoder=text_encoders[0],
@@ -2040,7 +2050,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20402050
if args.train_text_encoder:
20412051
text_encoder_one.train()
20422052
# set top parameter requires_grad = True for gradient checkpointing works
2043-
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
2053+
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
20442054
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
20452055
text_encoder_one.train()
20462056
if args.enable_t5_ti:
@@ -2148,7 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21482158
)
21492159

21502160
# handle guidance
2151-
if accelerator.unwrap_model(transformer).config.guidance_embeds:
2161+
if unwrap_model(transformer).config.guidance_embeds:
21522162
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
21532163
guidance = guidance.expand(model_input.shape[0])
21542164
else:
@@ -2290,9 +2300,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
22902300
pipeline = FluxPipeline.from_pretrained(
22912301
args.pretrained_model_name_or_path,
22922302
vae=vae,
2293-
text_encoder=accelerator.unwrap_model(text_encoder_one),
2294-
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
2295-
transformer=accelerator.unwrap_model(transformer),
2303+
text_encoder=unwrap_model(text_encoder_one),
2304+
text_encoder_2=unwrap_model(text_encoder_two),
2305+
transformer=unwrap_model(transformer),
22962306
revision=args.revision,
22972307
variant=args.variant,
22982308
torch_dtype=weight_dtype,

0 commit comments

Comments (0)