
Commit 8499008: fix tests
Parent: f3b427d

4 files changed: +16 additions, -6 deletions


src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py

Lines changed: 1 addition & 2 deletions
@@ -546,9 +546,8 @@ def __call__(
         sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
-        if self.scheduler.config.final_sigmas_type == "sigma_min":
+        if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
             # Replace the last sigma (which is zero) with the minimum sigma value
-            timesteps[-1] = timesteps[-2]
             self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]

         # 5. Prepare latent variables
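A note on the switched check: scheduler configs saved before final_sigmas_type existed simply do not contain that key, so attribute access can fail where a dict-style lookup with a default does not (the scheduler config behaves like a read-only dict). A minimal sketch with hypothetical stand-in configs, not the pipeline's actual objects:

# Hypothetical stand-in configs; the real code reads self.scheduler.config,
# which supports dict-style .get(), so the "zero" default applies there too.
legacy_config = {"num_train_timesteps": 1000}  # no final_sigmas_type registered
current_config = {"num_train_timesteps": 1000, "final_sigmas_type": "sigma_min"}

for cfg in (legacy_config, current_config):
    # A missing key falls back to "zero", matching the new guard above.
    if cfg.get("final_sigmas_type", "zero") == "sigma_min":
        print("clamp the final sigma to sigma_min")
    else:
        print("final sigma stays zero")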

src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py

Lines changed: 3 additions & 2 deletions
@@ -633,7 +633,6 @@ def __call__(
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
         if self.scheduler.config.final_sigmas_type == "sigma_min":
             # Replace the last sigma (which is zero) with the minimum sigma value
-            timesteps[-1] = timesteps[-2]
             self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]

         # 5. Prepare latent variables
@@ -687,7 +686,9 @@ def __call__(
                 c_in = 1 - current_t
                 c_skip = 1 - current_t
                 c_out = -current_t
-                timestep = current_t.expand(latents.shape[0]).to(transformer_dtype)  # [B, 1, T, 1, 1]
+                timestep = current_t.view(1, 1, 1, 1, 1).expand(
+                    latents.size(0), -1, latents.size(2), -1, -1
+                )  # [B, 1, T, 1, 1]

                 cond_latent = latents * c_in
                 cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
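For reference, a small sketch (tensor sizes here are made up) of the shape the new view/expand produces, in line with the [B, 1, T, 1, 1] comment; the old 1-D expand over the batch dimension did not match that layout:

import torch

latents = torch.randn(2, 16, 5, 8, 8)  # [B, C, T, H, W], illustrative sizes only
current_t = torch.tensor(0.7)          # scalar timestep for this denoising step

# Broadcast the scalar to [B, 1, T, 1, 1] rather than a flat [B] tensor.
timestep = current_t.view(1, 1, 1, 1, 1).expand(
    latents.size(0), -1, latents.size(2), -1, -1
)
print(timestep.shape)  # torch.Size([2, 1, 5, 1, 1])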

tests/pipelines/cosmos/test_cosmos2_text2image.py

Lines changed: 6 additions & 1 deletion
@@ -22,7 +22,12 @@
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, Cosmos2TextToImagePipeline, CosmosTransformer3DModel, FlowMatchEulerDiscreteScheduler
+from diffusers import (
+    AutoencoderKLWan,
+    Cosmos2TextToImagePipeline,
+    CosmosTransformer3DModel,
+    FlowMatchEulerDiscreteScheduler,
+)
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS

tests/pipelines/cosmos/test_cosmos2_video2world.py

Lines changed: 6 additions & 1 deletion
@@ -23,7 +23,12 @@
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, Cosmos2VideoToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import (
+    AutoencoderKLWan,
+    Cosmos2VideoToWorldPipeline,
+    CosmosTransformer3DModel,
+    FlowMatchEulerDiscreteScheduler,
+)
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
