Skip to content

Commit b69f149

Browse files
committed
fix latent caching
1 parent 93f5e04 commit b69f149

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
15121512
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
15131513
tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
15141514

1515+
vae_config_shift_factor = vae.config.shift_factor
1516+
vae_config_scaling_factor = vae.config.scaling_factor
15151517
if args.cache_latents:
15161518
latents_cache = []
15171519
for batch in tqdm(train_dataloader, desc="Caching latents"):
@@ -1685,7 +1687,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16851687
else:
16861688
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
16871689
model_input = vae.encode(pixel_values).latent_dist.sample()
1688-
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
1690+
1691+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
16891692
model_input = model_input.to(dtype=weight_dtype)
16901693

16911694
# Sample noise that we'll add to the latents

0 commit comments

Comments
 (0)