Commit a571ac9

fix caching latents
1 parent b3a7860 commit a571ac9

File tree

1 file changed: +14 −0 lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 14 additions & 0 deletions

@@ -1512,6 +1512,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
                 tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
 
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+    if args.validation_prompt is None:
+        del vae
+        free_memory()
+
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
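This commit only populates the cache; how the cached DiagonalGaussianDistribution objects are consumed later in the training loop is not part of the diff. As a hypothetical sketch modeled on similar diffusers DreamBooth scripts (the helper name, the step-indexed lookup, and the shift_factor/scaling_factor parameters are all assumptions, not shown in this commit), the cached entries would replace the per-step VAE encode roughly like this:

import torch

# Hypothetical sketch, not part of this commit: consuming the latents cache
# built above instead of running a VAE forward pass each training step.
# `args`, `vae`, `latents_cache`, `accelerator`, and `weight_dtype` are the
# names from the diff; everything else here is assumed.
def get_model_input(step, batch, args, vae, latents_cache, accelerator,
                    weight_dtype, shift_factor, scaling_factor):
    if args.cache_latents:
        # Sample from the cached DiagonalGaussianDistribution; no VAE call is
        # needed, which is why the diff can delete the VAE when no validation
        # prompt will require it later.
        model_input = latents_cache[step].sample()
    else:
        with torch.no_grad():
            pixel_values = batch["pixel_values"].to(
                accelerator.device, non_blocking=True, dtype=weight_dtype
            )
            model_input = vae.encode(pixel_values).latent_dist.sample()
    # SD3's VAE latents are shifted and scaled before entering the transformer;
    # the factors would have to be read from vae.config before `del vae` runs.
    return (model_input - shift_factor) * scaling_factor

Note that indexing the cache by step assumes the dataloader yields batches in the same order as during caching, i.e. without reshuffling between epochs of the caching pass and training.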
