Skip to content

Commit 2b4827f

Browse files
committed
fix
1 parent d2931e0 commit 2b4827f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18441844
sigma = sigma.unsqueeze(-1)
18451845
return sigma
18461846

1847+
has_guidance = unwrap_model(transformer).config.guidance_embeds
18471848
for epoch in range(first_epoch, args.num_train_epochs):
18481849
transformer.train()
18491850
if args.train_text_encoder:
@@ -1906,10 +1907,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
19061907
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
19071908
if args.vae_encode_mode == "sample":
19081909
model_input = vae.encode(pixel_values).latent_dist.sample()
1909-
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
1910+
if has_image_input:
1911+
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
19101912
else:
19111913
model_input = vae.encode(pixel_values).latent_dist.mode()
1912-
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
1914+
if has_image_input:
1915+
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
19131916
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
19141917
model_input = model_input.to(dtype=weight_dtype)
19151918
if has_image_input:
@@ -1975,8 +1978,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
19751978
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
19761979

19771980
# Kontext always has guidance
1978-
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
1979-
guidance = guidance.expand(model_input.shape[0])
1981+
guidance = None
1982+
if has_guidance:
1983+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
1984+
guidance = guidance.expand(model_input.shape[0])
19801985

19811986
# Predict the noise residual
19821987
model_pred = transformer(

0 commit comments

Comments
 (0)