Skip to content

Commit 710fcae

Browse files
committed
fix validation
1 parent d434db3 commit 710fcae

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def log_validation(
230230
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
231231
autocast_ctx = torch.autocast(accelerator.device.type)
232232

233+
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
233234
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
234235
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
235236
)
@@ -2194,16 +2195,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21942195
accelerator.backward(loss)
21952196
if accelerator.sync_gradients:
21962197
if not freeze_text_encoder:
2197-
if args.train_text_encoder:
2198+
if args.train_text_encoder: # text encoder tuning
21982199
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
21992200
elif pure_textual_inversion:
2200-
params_to_clip = itertools.chain(
2201-
text_encoder_one.parameters(), text_encoder_two.parameters()
2202-
)
2201+
if args.enable_t5_ti:
2202+
params_to_clip = itertools.chain(
2203+
text_encoder_one.parameters(), text_encoder_two.parameters()
2204+
)
2205+
else:
2206+
params_to_clip = itertools.chain(
2207+
text_encoder_one.parameters()
2208+
)
22032209
else:
2204-
params_to_clip = itertools.chain(
2205-
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
2206-
)
2210+
if args.enable_t5_ti:
2211+
params_to_clip = itertools.chain(
2212+
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
2213+
)
2214+
else:
2215+
params_to_clip = itertools.chain(transformer.parameters(),
2216+
text_encoder_one.parameters())
22072217
else:
22082218
params_to_clip = itertools.chain(transformer.parameters())
22092219
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2260,8 +2270,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
22602270
if accelerator.is_main_process:
22612271
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
22622272
# create pipeline
2263-
if freeze_text_encoder:
2273+
if freeze_text_encoder: # no text encoder one, two optimizations
22642274
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
2275+
text_encoder_one.to(weight_dtype)
2276+
text_encoder_two.to(weight_dtype)
2277+
22652278
pipeline = FluxPipeline.from_pretrained(
22662279
args.pretrained_model_name_or_path,
22672280
vae=vae,
@@ -2287,9 +2300,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
22872300
if freeze_text_encoder:
22882301
del text_encoder_one, text_encoder_two
22892302
free_memory()
2290-
elif args.train_text_encoder:
2291-
del text_encoder_two
2292-
free_memory()
22932303

22942304
# Save the lora layers
22952305
accelerator.wait_for_everyone()

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def log_validation(
184184
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
185185
autocast_ctx = torch.autocast(accelerator.device.type)
186186

187+
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
187188
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
188189
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
189190
)

0 commit comments

Comments
 (0)