@@ -1929,6 +1929,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
 
     if args.cache_latents:
         latents_cache = []
+        # Store vae config before potential deletion
+        vae_scaling_factor = vae.config.scaling_factor
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(
@@ -1940,6 +1942,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
             del vae
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+    else:
+        vae_scaling_factor = vae.config.scaling_factor
 
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
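The two hunks above capture the VAE scaling factor in a plain variable before `del vae` can run under `--cache_latents`, and also set it in the non-caching branch so later code has a single name to use. A minimal sketch of the pattern, with assumed, simplified names (no `args`, dataloader, or actual encoding) rather than the script's exact code:

```python
# Minimal sketch, assuming a default AutoencoderKL; not the script's actual setup.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()   # default config carries a scaling_factor (0.18215)
cache_latents = True    # stands in for args.cache_latents

if cache_latents:
    # Read the plain float out of the config *before* the VAE may be freed.
    vae_scaling_factor = vae.config.scaling_factor
    # ... encode the training images to latents here ...
    del vae
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
else:
    vae_scaling_factor = vae.config.scaling_factor

# Later code can scale latents without touching the deleted `vae` object.
latents = torch.randn(1, 4, 64, 64)
model_input = latents * vae_scaling_factor
```

Without the cached value, the scaling step later in the training loop would still reference `vae.config.scaling_factor` after the VAE object had been deleted and fail with a NameError/UnboundLocalError.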
@@ -2109,13 +2113,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     model_input = vae.encode(pixel_values).latent_dist.sample()
 
                 if latents_mean is None and latents_std is None:
-                    model_input = model_input * vae.config.scaling_factor
+                    model_input = model_input * vae_scaling_factor
                     if args.pretrained_vae_model_name_or_path is None:
                         model_input = model_input.to(weight_dtype)
                 else:
                     latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
                     latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
-                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
                     model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
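The last hunk swaps `vae.config.scaling_factor` for the cached `vae_scaling_factor` at the two places the latents are scaled, so the same code path works whether or not the VAE is still in memory. A toy illustration of those two branches, using assumed placeholder values rather than the script's tensors:

```python
# Toy illustration of the two scaling branches; values are assumptions, not the script's.
import torch

vae_scaling_factor = 0.13025               # e.g. the SDXL VAE's configured scaling factor
model_input = torch.randn(2, 4, 128, 128)  # stand-in for encoded latents

latents_mean = None                        # set when the VAE config provides latents_mean/std
latents_std = None

if latents_mean is None and latents_std is None:
    # Plain scaling, as in the first branch of the hunk above.
    model_input = model_input * vae_scaling_factor
else:
    # Standardize with the published statistics, then apply the scaling factor.
    model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
```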