Skip to content

Commit bdd6cae

Browse files
committed
add intermediate embeddings saving when checkpointing is enabled
1 parent 90e9517 commit bdd6cae

File tree

1 file changed: +3 additions, -0 deletions

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2247,6 +2247,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
+                        if args.train_text_encoder_ti:
+                            embedding_handler.save_embeddings(
+                                f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors")
                         logger.info(f"Saved state to {save_path}")

                 logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}

0 commit comments