Commit 6226c8d ("update"), 1 parent: f046889

7 files changed: +68, -29 lines

docs/source/en/api/pipelines/cosmos.md

Lines changed: 16 additions & 0 deletions
@@ -36,6 +36,22 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
   - all
   - __call__
 
+## Cosmos2TextToImagePipeline
+
+[[autodoc]] Cosmos2TextToImagePipeline
+  - all
+  - __call__
+
+## Cosmos2VideoToWorldPipeline
+
+[[autodoc]] Cosmos2VideoToWorldPipeline
+  - all
+  - __call__
+
 ## CosmosPipelineOutput
 
 [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
+
+## CosmosImagePipelineOutput
+
+[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 2 additions & 0 deletions
@@ -568,6 +568,8 @@ def forward(
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
         hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
+        # NOTE: The permutation order here is not the inverse of the patching operation, as one might expect.
+        # This may confuse the reader, but it is correct.
         hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
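For intuition, here is a minimal, self-contained sketch of the unpatchify chain above, with hypothetical sizes chosen only to make the shapes concrete. One plausible reading of the NOTE: the per-token channel layout is (p_h, p_w, p_t, C) rather than (p_t, p_h, p_w, C), so the permute must pick dimensions in a non-mirrored order to interleave them back as (T, H, W):

    import torch

    # Hypothetical sizes for illustration only.
    B, C = 1, 16
    p_t, p_h, p_w = 1, 2, 2
    T, H, W = 2, 8, 8
    post_t, post_h, post_w = T // p_t, H // p_h, W // p_w

    # One token per patch; channels laid out as (p_h, p_w, p_t, C) to match the unflatten below.
    x = torch.randn(B, post_t * post_h * post_w, p_h * p_w * p_t * C)

    x = x.unflatten(2, (p_h, p_w, p_t, -1))          # (B, S, p_h, p_w, p_t, C)
    x = x.unflatten(1, (post_t, post_h, post_w))     # (B, T', H', W', p_h, p_w, p_t, C)
    x = x.permute(0, 7, 1, 6, 2, 4, 3, 5)            # (B, C, T', p_t, H', p_h, W', p_w)
    x = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)  # (B, C, T, H, W)
    assert x.shape == (B, C, T, H, W)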

src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py

Lines changed: 7 additions & 7 deletions
@@ -21,7 +21,7 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...models import AutoencoderKLWan, CosmosTransformer3DModel
-from ...schedulers import EDMEulerScheduler
+from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -134,7 +134,7 @@ def retrieve_timesteps(
 
 class Cosmos2TextToImagePipeline(DiffusionPipeline):
     r"""
-    Pipeline for text-to-image generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
+    Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -149,7 +149,7 @@ class Cosmos2TextToImagePipeline(DiffusionPipeline):
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         transformer ([`CosmosTransformer3DModel`]):
             Conditional Transformer to denoise the encoded image latents.
-        scheduler ([`EDMEulerScheduler`]):
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
@@ -166,7 +166,7 @@ def __init__(
         tokenizer: T5TokenizerFast,
         transformer: CosmosTransformer3DModel,
         vae: AutoencoderKLWan,
-        scheduler: EDMEulerScheduler,
+        scheduler: FlowMatchEulerDiscreteScheduler,
         safety_checker: CosmosSafetyChecker = None,
     ):
         super().__init__()
@@ -543,7 +543,8 @@ def __call__(
         )
 
         # 4. Prepare timesteps
-        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=torch.float64)
+        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
         if self.scheduler.config.final_sigmas_type == "sigma_min":
             # Replace the last sigma (which is zero) with the minimum sigma value
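A note on the `final_sigmas_type == "sigma_min"` fixup: the scheduler's sigma schedule ends at exactly zero, and the denoising loop later divides by the current sigma when converting x0-style predictions into velocities, so the pipeline clamps the last timestep and sigma to the second-to-last (sigma_min) values. A toy illustration with made-up values:

    import torch

    sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.002, 0.0])
    timesteps = sigmas * 1000  # illustrative mapping, not the scheduler's
    timesteps[-1] = timesteps[-2]
    sigmas[-1] = sigmas[-2]    # trajectory now ends at sigma_min instead of exactly 0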
@@ -577,14 +578,13 @@ def __call__(
                 continue
 
             self._current_timestep = t
-            timestep = t.expand(latents.shape[0]).to(transformer_dtype)
             current_sigma = self.scheduler.sigmas[i]
 
             current_t = current_sigma / (current_sigma + 1)
             c_in = 1 - current_t
             c_skip = 1 - current_t
             c_out = -current_t
-            timestep = current_t.expand(latents.shape[0]).to(transformer_dtype)
+            timestep = current_t.expand(latents.shape[0]).to(transformer_dtype)  # shape: [B]
 
             latent_model_input = latents * c_in
             latent_model_input = latent_model_input.to(transformer_dtype)
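One way to read the inlined coefficients in this commit (a sketch, not pipeline code; the helper name is hypothetical): latents live in sigma-space, x = x0 + sigma * n, and t = sigma / (sigma + 1) maps them onto the rectified-flow interpolant. With c_in = c_skip = 1 - t and c_out = -t, scaling the input gives c_in * x = (1 - t) * x0 + t * n, and c_skip * x + c_out * v recovers x0 exactly when the transformer predicts the flow velocity v = n - x0:

    import torch

    def flow_preconditioning(sigma: torch.Tensor):
        # Assumes sigma_data = 1.0, the value the pipelines register on the scheduler.
        t = sigma / (sigma + 1)
        c_in = 1 - t    # = 1 / (sigma + 1)
        c_skip = 1 - t
        c_out = -t      # = -sigma / (sigma + 1)
        return t, c_in, c_skip, c_out

    # Schematic use inside the denoising loop:
    #   t, c_in, c_skip, c_out = flow_preconditioning(current_sigma)
    #   v_pred = transformer(latents * c_in, timestep=t)  # predicts n - x0
    #   x0_pred = c_skip * latents + c_out * v_pred       # algebra above recovers x0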

src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py

Lines changed: 39 additions & 19 deletions
@@ -22,7 +22,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLWan, CosmosTransformer3DModel
-from ...schedulers import EDMEulerScheduler
+from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
@@ -153,7 +153,7 @@ def retrieve_latents(
 
 class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
     r"""
-    Pipeline for text-to-image generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
+    Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -168,7 +168,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         transformer ([`CosmosTransformer3DModel`]):
             Conditional Transformer to denoise the encoded image latents.
-        scheduler ([`EDMEulerScheduler`]):
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
@@ -185,7 +185,7 @@ def __init__(
         tokenizer: T5TokenizerFast,
         transformer: CosmosTransformer3DModel,
         vae: AutoencoderKLWan,
-        scheduler: EDMEulerScheduler,
+        scheduler: FlowMatchEulerDiscreteScheduler,
         safety_checker: CosmosSafetyChecker = None,
     ):
         super().__init__()
@@ -206,6 +206,18 @@ def __init__(
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
+        self.sigma_max = 80.0
+        self.sigma_min = 0.002
+        self.sigma_data = 1.0
+        self.final_sigmas_type = "sigma_min"
+        if self.scheduler is not None:
+            self.scheduler.register_to_config(
+                sigma_max=self.sigma_max,
+                sigma_min=self.sigma_min,
+                sigma_data=self.sigma_data,
+                final_sigmas_type=self.final_sigmas_type,
+            )
+
     # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
         self,
@@ -340,7 +352,7 @@ def prepare_latents(
         num_channels_latents: 16,
         height: int = 704,
         width: int = 1280,
-        num_frames: int = 77,
+        num_frames: int = 93,
         do_classifier_free_guidance: bool = True,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
@@ -472,7 +484,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 704,
         width: int = 1280,
-        num_frames: int = 77,
+        num_frames: int = 93,
         num_inference_steps: int = 35,
         guidance_scale: float = 7.0,
         fps: int = 16,
@@ -505,7 +517,7 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, defaults to `1280`):
                 The width in pixels of the generated image.
-            num_frames (`int`, defaults to `77`):
+            num_frames (`int`, defaults to `93`):
                 The number of frames in the generated video.
             num_inference_steps (`int`, defaults to `35`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -616,7 +628,13 @@ def __call__(
         )
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
+        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
+        if self.scheduler.config.final_sigmas_type == "sigma_min":
+            # Replace the last sigma (which is zero) with the minimum sigma value
+            timesteps[-1] = timesteps[-2]
+            self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
 
         # 5. Prepare latent variables
         vae_dtype = self.vae.dtype
@@ -651,7 +669,7 @@ def __call__(
 
         padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
         sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
-        t_conditioning = self.scheduler.precondition_noise(sigma_conditioning)
+        t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
 
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -663,12 +681,15 @@ def __call__(
                 continue
 
             self._current_timestep = t
-            timestep = t.view(1, 1, 1, 1, 1).expand(
-                latents.size(0), -1, latents.size(2), -1, -1
-            )  # [B, 1, T, 1, 1]
             current_sigma = self.scheduler.sigmas[i]
 
-            cond_latent = self.scheduler.scale_model_input(latents, t)
+            current_t = current_sigma / (current_sigma + 1)
+            c_in = 1 - current_t
+            c_skip = 1 - current_t
+            c_out = -current_t
+            timestep = current_t.expand(latents.shape[0]).to(transformer_dtype)  # shape: [B]
+
+            cond_latent = latents * c_in
             cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
             cond_latent = cond_latent.to(transformer_dtype)
             cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
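The cond_indicator mixing above is what makes this "video-to-world" generation: context frames keep their clean latents and the small fixed conditioning timestep, while generated frames see the preconditioned noisy latents and the current timestep. A shape-only sketch with hypothetical sizes:

    import torch

    B, C, T, H, W = 1, 16, 4, 8, 8
    conditioning_latents = torch.randn(B, C, T, H, W)  # clean context latents
    latents = torch.randn(B, C, T, H, W)               # current noisy latents
    c_in = torch.tensor(0.5)                           # stand-in for 1 - current_t

    cond_indicator = torch.zeros(B, 1, T, 1, 1)
    cond_indicator[:, :, :1] = 1.0  # first latent frame is given as context

    cond_latent = latents * c_in
    cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent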
@@ -683,11 +704,11 @@ def __call__(
                 padding_mask=padding_mask,
                 return_dict=False,
             )[0]
-            noise_pred = self.scheduler.precondition_outputs(latents, noise_pred, current_sigma)
+            noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
             noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred
 
             if self.do_classifier_free_guidance:
-                uncond_latent = self.scheduler.scale_model_input(latents, t)
+                uncond_latent = latents * c_in
                 uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent
                 uncond_latent = uncond_latent.to(transformer_dtype)
                 uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
@@ -702,15 +723,14 @@ def __call__(
                     padding_mask=padding_mask,
                     return_dict=False,
                 )[0]
-                noise_pred_uncond = self.scheduler.precondition_outputs(latents, noise_pred_uncond, current_sigma)
+                noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
                 noise_pred_uncond = (
                     uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond
                 )
                 noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
 
-            latents = self.scheduler.step(
-                noise_pred, t, latents, pred_original_sample=noise_pred, return_dict=False
-            )[0]
+            noise_pred = (latents - noise_pred) / current_sigma
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
             if callback_on_step_end is not None:
                 callback_kwargs = {}
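The last change hands the scheduler a velocity instead of an x0 prediction. Dividing (latents - x0_pred) by sigma gives the slope of the straight path x(sigma) = x0 + sigma * n, and FlowMatchEulerDiscreteScheduler.step then advances the sample with what is, at its core, a plain Euler update. A hedged sketch (names hypothetical):

    import torch

    def flow_euler_step(latents, x0_pred, sigma, sigma_next):
        velocity = (latents - x0_pred) / sigma  # dx/dsigma for x = x0 + sigma * n
        return latents + (sigma_next - sigma) * velocity

    x = torch.randn(1, 16, 4, 8, 8)
    x0_pred = torch.zeros_like(x)
    x = flow_euler_step(x, x0_pred, sigma=torch.tensor(1.0), sigma_next=torch.tensor(0.75))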

src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def retrieve_timesteps(
 
 class CosmosTextToWorldPipeline(DiffusionPipeline):
     r"""
-    Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
+    Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 2 additions & 1 deletion
@@ -174,7 +174,8 @@ def retrieve_latents(
 
 class CosmosVideoToWorldPipeline(DiffusionPipeline):
     r"""
-    Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
+    Pipeline for image-to-world and video-to-world generation using [Cosmos
+    Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).

src/diffusers/pipelines/cosmos/pipeline_output.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ class CosmosPipelineOutput(BaseOutput):
 @dataclass
 class CosmosImagePipelineOutput(BaseOutput):
     """
-    Output class for CogView3 pipelines.
+    Output class for Cosmos any-to-image pipelines.
 
     Args:
         images (`List[PIL.Image.Image]` or `np.ndarray`)
