@@ -633,9 +633,13 @@ def __call__(
633633 current_sigma = self .scheduler .sigmas [i ]
634634 is_augment_sigma_greater = augment_sigma >= current_sigma
635635
636+ c_in_augment = self .scheduler ._get_conditioning_c_in (augment_sigma )
637+ c_in_original = self .scheduler ._get_conditioning_c_in (current_sigma )
638+
636639 current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
637640 cond_noise = randn_tensor (latents .shape , generator = generator , device = device , dtype = torch .float32 )
638641 cond_latent = conditioning_latents + cond_noise * augment_sigma [:, None , None , None , None ]
642+ cond_latent = cond_latent * c_in_augment / c_in_original
639643 cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator ) * latents
640644 cond_latent = self .scheduler .scale_model_input (cond_latent , t )
641645 cond_latent = cond_latent .to (transformer_dtype )
@@ -654,6 +658,7 @@ def __call__(
654658 current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
655659 uncond_noise = randn_tensor (latents .shape , generator = generator , device = device , dtype = torch .float32 )
656660 uncond_latent = conditioning_latents + uncond_noise * augment_sigma [:, None , None , None , None ]
661+ uncond_latent = uncond_latent * c_in_augment / c_in_original
657662 uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator ) * latents
658663 uncond_latent = self .scheduler .scale_model_input (uncond_latent , t )
659664 uncond_latent = uncond_latent .to (transformer_dtype )
0 commit comments