Skip to content

Commit bf9190f

Browse files
committed
fix
1 parent 7c54eb1 commit bf9190f

File tree

3 files changed

+89
-54
lines changed

3 files changed

+89
-54
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -496,9 +496,6 @@ def __call__(
496496
max_sequence_length=max_sequence_length,
497497
)
498498

499-
if self.do_classifier_free_guidance:
500-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
501-
502499
# 4. Prepare timesteps
503500
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
504501

@@ -531,7 +528,7 @@ def __call__(
531528
self._current_timestep = t
532529
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
533530

534-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
531+
latent_model_input = latents
535532
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
536533
latent_model_input = latent_model_input.to(transformer_dtype)
537534

@@ -543,13 +540,29 @@ def __call__(
543540
padding_mask=padding_mask,
544541
return_dict=False,
545542
)[0]
543+
if self.do_classifier_free_guidance:
544+
noise_pred_uncond = self.transformer(
545+
hidden_states=latent_model_input,
546+
timestep=timestep,
547+
encoder_hidden_states=negative_prompt_embeds,
548+
fps=fps,
549+
padding_mask=padding_mask,
550+
return_dict=False,
551+
)[0]
552+
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
553+
554+
# pred_original_sample (x0)
555+
noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
556+
self.scheduler._step_index -= 1
546557

547558
if self.do_classifier_free_guidance:
548-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
549-
noise_pred = noise_pred_text + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
559+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
560+
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
550561

551-
# compute the previous noisy sample x_t -> x_t-1
552-
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
562+
# pred_sample (eps)
563+
latents = self.scheduler.step(
564+
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
565+
)[0]
553566

554567
if callback_on_step_end is not None:
555568
callback_kwargs = {}
@@ -559,6 +572,7 @@ def __call__(
559572

560573
latents = callback_outputs.pop("latents", latents)
561574
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
575+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
562576

563577
# call the callback, if provided
564578
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def prepare_latents(
318318
height: int = 704,
319319
width: int = 1280,
320320
num_frames: int = 121,
321+
do_classifier_free_guidance: bool = True,
321322
input_frames_guidance: bool = False,
322323
dtype: Optional[torch.dtype] = None,
323324
device: Optional[torch.device] = None,
@@ -331,11 +332,12 @@ def prepare_latents(
331332
)
332333

333334
num_cond_frames = video.size(2)
334-
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
335335
if num_cond_frames >= num_frames:
336336
# Take the last `num_frames` frames for conditioning
337+
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
337338
video = video[:, :, -num_frames:]
338339
else:
340+
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
339341
num_padding_frames = num_frames - num_cond_frames
340342
padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
341343
video = torch.cat([video, padding], dim=2)
@@ -374,22 +376,25 @@ def prepare_latents(
374376
if latents is None:
375377
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
376378
else:
377-
latents = latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
379+
latents = latents.to(device=device, dtype=dtype)
378380

379381
latents = latents * self.scheduler.config.sigma_max
380382

381-
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
382-
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
383-
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
384-
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
385-
386383
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
387384
ones_padding = latents.new_ones(padding_shape)
388385
zeros_padding = latents.new_zeros(padding_shape)
386+
387+
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
388+
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
389389
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
390-
uncond_mask = zeros_padding
391-
if input_frames_guidance:
392-
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
390+
391+
uncond_indicator = uncond_mask = None
392+
if do_classifier_free_guidance:
393+
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
394+
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
395+
uncond_mask = zeros_padding
396+
if not input_frames_guidance:
397+
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
393398

394399
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
395400

@@ -599,24 +604,24 @@ def __call__(
599604
height,
600605
width,
601606
num_frames,
607+
self.do_classifier_free_guidance,
602608
input_frames_guidance,
603609
torch.float32,
604610
device,
605611
generator,
606612
latents,
607613
)
608-
uncond_mask = uncond_mask.to(transformer_dtype)
609614
cond_mask = cond_mask.to(transformer_dtype)
615+
if self.do_classifier_free_guidance:
616+
uncond_mask = uncond_mask.to(transformer_dtype)
617+
610618
augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
611619
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
612620

613621
# 6. Denoising loop
614622
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
615623
self._num_timesteps = len(timesteps)
616624

617-
if not guidance_scale > 1.0:
618-
raise ValueError("Running inference without CFG is not yet supported. Please set `guidance_scale > 1`.")
619-
620625
with self.progress_bar(total=num_inference_steps) as progress_bar:
621626
for i, t in enumerate(timesteps):
622627
if self.interrupt:
@@ -628,31 +633,14 @@ def __call__(
628633
current_sigma = self.scheduler.sigmas[i]
629634
is_augment_sigma_greater = augment_sigma >= current_sigma
630635

631-
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
632-
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
633-
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
634-
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
635-
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
636-
637636
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
638637
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
639638
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
640-
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
641639
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
642-
643-
uncond_latent = uncond_latent.to(transformer_dtype)
640+
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
644641
cond_latent = cond_latent.to(transformer_dtype)
645642

646-
noise_pred_uncond = self.transformer(
647-
hidden_states=uncond_latent,
648-
timestep=timestep,
649-
encoder_hidden_states=negative_prompt_embeds,
650-
fps=fps,
651-
condition_mask=uncond_mask,
652-
padding_mask=padding_mask,
653-
return_dict=False,
654-
)[0]
655-
noise_pred_cond = self.transformer(
643+
noise_pred = self.transformer(
656644
hidden_states=cond_latent,
657645
timestep=timestep,
658646
encoder_hidden_states=prompt_embeds,
@@ -662,18 +650,48 @@ def __call__(
662650
return_dict=False,
663651
)[0]
664652

665-
noise_pred = torch.cat([noise_pred_uncond, noise_pred_cond], dim=0)
666-
noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
667-
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
668-
669-
noise_pred_cond = (
670-
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
671-
)
672-
noise_pred_uncond = (
673-
current_uncond_indicator * conditioning_latents
674-
+ (1 - current_uncond_indicator) * noise_pred_uncond
675-
)
676-
latents = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
653+
if self.do_classifier_free_guidance:
654+
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
655+
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
656+
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
657+
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
658+
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
659+
uncond_latent = uncond_latent.to(transformer_dtype)
660+
661+
noise_pred_uncond = self.transformer(
662+
hidden_states=uncond_latent,
663+
timestep=timestep,
664+
encoder_hidden_states=negative_prompt_embeds,
665+
fps=fps,
666+
condition_mask=uncond_mask,
667+
padding_mask=padding_mask,
668+
return_dict=False,
669+
)[0]
670+
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
671+
672+
# pred_original_sample (x0)
673+
noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
674+
self.scheduler._step_index -= 1
675+
676+
if self.do_classifier_free_guidance:
677+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
678+
noise_pred_uncond = (
679+
current_uncond_indicator * conditioning_latents
680+
+ (1 - current_uncond_indicator) * noise_pred_uncond
681+
)
682+
noise_pred_cond = (
683+
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
684+
)
685+
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
686+
else:
687+
noise_pred = (
688+
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
689+
)
690+
691+
# pred_sample (eps)
692+
latents = self.scheduler.step(
693+
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
694+
)[0]
677695

678696
if callback_on_step_end is not None:
679697
callback_kwargs = {}
@@ -683,6 +701,7 @@ def __call__(
683701

684702
latents = callback_outputs.pop("latents", latents)
685703
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
704+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
686705

687706
# call the callback, if provided
688707
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def step(
318318
s_noise: float = 1.0,
319319
generator: Optional[torch.Generator] = None,
320320
return_dict: bool = True,
321+
pred_original_sample: Optional[torch.Tensor] = None,
321322
) -> Union[EDMEulerSchedulerOutput, Tuple]:
322323
"""
323324
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -381,7 +382,8 @@ def step(
381382
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
382383

383384
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
384-
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
385+
if pred_original_sample is None:
386+
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
385387

386388
# 2. Convert to an ODE derivative
387389
derivative = (sample - pred_original_sample) / sigma_hat

0 commit comments

Comments
 (0)