Skip to content

Commit 64fc4fe

Browse files
committed
update
1 parent bf9190f commit 64fc4fe

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,13 @@ def __call__(
633633
current_sigma = self.scheduler.sigmas[i]
634634
is_augment_sigma_greater = augment_sigma >= current_sigma
635635

636+
c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
637+
c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
638+
636639
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
637640
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
638641
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
642+
cond_latent = cond_latent * c_in_augment / c_in_original
639643
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
640644
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
641645
cond_latent = cond_latent.to(transformer_dtype)
@@ -654,6 +658,7 @@ def __call__(
654658
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
655659
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
656660
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
661+
uncond_latent = uncond_latent * c_in_augment / c_in_original
657662
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
658663
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
659664
uncond_latent = uncond_latent.to(transformer_dtype)

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def set_begin_index(self, begin_index: int = 0):
161161
self._begin_index = begin_index
162162

163163
def precondition_inputs(self, sample, sigma):
164-
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
164+
c_in = self._get_conditioning_c_in(sigma)
165165
scaled_sample = sample * c_in
166166
return scaled_sample
167167

@@ -440,5 +440,9 @@ def add_noise(
440440
noisy_samples = original_samples + noise * sigma
441441
return noisy_samples
442442

443+
def _get_conditioning_c_in(self, sigma):
444+
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
445+
return c_in
446+
443447
def __len__(self):
444448
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)