@@ -1844,6 +1844,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             sigma = sigma.unsqueeze(-1)
         return sigma
 
+    has_guidance = unwrap_model(transformer).config.guidance_embeds
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
         if args.train_text_encoder:
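For reference, the new `has_guidance` flag keys off the transformer config: `FluxTransformer2DModel` exposes a boolean `guidance_embeds` entry, typically True for guidance-distilled checkpoints (e.g. FLUX.1-dev) and False otherwise. A minimal sketch of that lookup, assuming `transformer` may be wrapped by accelerate/DDP (the `base_model` name is only for illustration):

    # Sketch only: read the config flag from the underlying model, not the wrapper.
    base_model = transformer.module if hasattr(transformer, "module") else transformer
    has_guidance = bool(base_model.config.guidance_embeds)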
@@ -1906,10 +1907,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
                 if args.vae_encode_mode == "sample":
                     model_input = vae.encode(pixel_values).latent_dist.sample()
-                    cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
+                    if has_image_input:
+                        cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
                 else:
                     model_input = vae.encode(pixel_values).latent_dist.mode()
-                    cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+                    if has_image_input:
+                        cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
                 if has_image_input:
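As background, the `vae_encode_mode` switch above picks between the two ways a diffusers `DiagonalGaussianDistribution` can be turned into latents: `.sample()` draws a reparameterized sample from the posterior, while `.mode()` deterministically returns its mean. A minimal standalone sketch (not the training script's code; the checkpoint name is only an example):

    import torch
    from diffusers import AutoencoderKL

    # Sketch: stochastic vs. deterministic VAE encoding of a dummy image batch.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    pixels = torch.randn(1, 3, 256, 256)  # stand-in for preprocessed images in [-1, 1]
    posterior = vae.encode(pixels).latent_dist
    stochastic = posterior.sample()   # changes from run to run
    deterministic = posterior.mode()  # always the posterior mean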
@@ -1975,8 +1978,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
 
                 # Kontext always has guidance
-                guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
-                guidance = guidance.expand(model_input.shape[0])
+                guidance = None
+                if has_guidance:
+                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+                    guidance = guidance.expand(model_input.shape[0])
 
                 # Predict the noise residual
                 model_pred = transformer(
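One small note on the retained `expand` pattern: it broadcasts the single scalar guidance value to a per-sample vector of length `batch_size` without copying the data; `torch.full` is an equivalent, slightly more direct spelling. A tiny sketch:

    import torch

    # Sketch: expand() yields a broadcast view equal to a materialized full().
    guidance = torch.tensor([3.5]).expand(4)  # shape (4,), no data copy
    assert torch.equal(guidance, torch.full((4,), 3.5))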