Skip to content

Commit 0d56c0c

Browse files
committed
support use_flow_sigmas in EDM scheduler instead of maintain cosmos-specific scheduler
1 parent c2ab6c8 commit 0d56c0c

File tree

7 files changed

+114
-512
lines changed

7 files changed

+114
-512
lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
CosmosTextToWorldPipeline,
1515
CosmosTransformer3DModel,
1616
EDMEulerScheduler,
17-
FlowMatchEulerEDMCosmos2_0Scheduler,
1817
)
1918

2019

@@ -187,6 +186,51 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
187186
"concat_padding_mask": True,
188187
"extra_pos_embed_type": None,
189188
},
189+
"Cosmos-2.0-Diffusion-14B-Text2Image": {
190+
"in_channels": 16,
191+
"out_channels": 16,
192+
"num_attention_heads": 40,
193+
"attention_head_dim": 128,
194+
"num_layers": 36,
195+
"mlp_ratio": 4.0,
196+
"text_embed_dim": 1024,
197+
"adaln_lora_dim": 256,
198+
"max_size": (128, 240, 240),
199+
"patch_size": (1, 2, 2),
200+
"rope_scale": (20 / 24, 2.0, 2.0),
201+
"concat_padding_mask": True,
202+
"extra_pos_embed_type": None,
203+
},
204+
"Cosmos-2.0-Diffusion-2B-Video2World": {
205+
"in_channels": 16 + 1,
206+
"out_channels": 16,
207+
"num_attention_heads": 16,
208+
"attention_head_dim": 128,
209+
"num_layers": 28,
210+
"mlp_ratio": 4.0,
211+
"text_embed_dim": 1024,
212+
"adaln_lora_dim": 256,
213+
"max_size": (128, 240, 240),
214+
"patch_size": (1, 2, 2),
215+
"rope_scale": (1.0, 1.0, 1.0),
216+
"concat_padding_mask": True,
217+
"extra_pos_embed_type": None,
218+
},
219+
"Cosmos-2.0-Diffusion-14B-Video2World": {
220+
"in_channels": 16 + 1,
221+
"out_channels": 16,
222+
"num_attention_heads": 40,
223+
"attention_head_dim": 128,
224+
"num_layers": 36,
225+
"mlp_ratio": 4.0,
226+
"text_embed_dim": 1024,
227+
"adaln_lora_dim": 256,
228+
"max_size": (128, 240, 240),
229+
"patch_size": (1, 2, 2),
230+
"rope_scale": (20 / 24, 2.0, 2.0),
231+
"concat_padding_mask": True,
232+
"extra_pos_embed_type": None,
233+
},
190234
}
191235

192236
VAE_KEYS_RENAME_DICT = {
@@ -352,8 +396,8 @@ def convert_vae(vae_type: str):
352396
return vae
353397

354398

355-
def save_pipeline_cosmos_1_0(args, transformer, vae, dtype):
356-
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
399+
def save_pipeline_cosmos_1_0(args, transformer, vae):
400+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
357401
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
358402
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
359403
# So, the sigma_min values that is used is the default value of 0.002.
@@ -378,11 +422,11 @@ def save_pipeline_cosmos_1_0(args, transformer, vae, dtype):
378422
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
379423

380424

381-
def save_pipeline_cosmos_2_0(args, transformer, vae, dtype):
382-
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
425+
def save_pipeline_cosmos_2_0(args, transformer, vae):
426+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
383427
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
384428

385-
scheduler = FlowMatchEulerEDMCosmos2_0Scheduler(
429+
scheduler = EDMEulerScheduler(
386430
sigma_min=0.0002,
387431
sigma_max=80,
388432
sigma_data=1.0,
@@ -391,6 +435,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae, dtype):
391435
prediction_type="epsilon",
392436
rho=7.0,
393437
final_sigmas_type="sigma_min",
438+
use_flow_sigmas=True,
394439
)
395440

396441
pipe = CosmosTextToImagePipeline(
@@ -458,8 +503,8 @@ def get_args():
458503

459504
if args.save_pipeline:
460505
if "Cosmos-1.0" in args.transformer_type:
461-
save_pipeline_cosmos_1_0(args, transformer, vae, dtype)
506+
save_pipeline_cosmos_1_0(args, transformer, vae)
462507
elif "Cosmos-2.0" in args.transformer_type:
463-
save_pipeline_cosmos_2_0(args, transformer, vae, dtype)
508+
save_pipeline_cosmos_2_0(args, transformer, vae)
464509
else:
465510
assert False

src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@
271271
"EulerAncestralDiscreteScheduler",
272272
"EulerDiscreteScheduler",
273273
"FlowMatchEulerDiscreteScheduler",
274-
"FlowMatchEulerEDMCosmos2_0Scheduler",
275274
"FlowMatchHeunDiscreteScheduler",
276275
"FlowMatchLCMScheduler",
277276
"HeunDiscreteScheduler",
@@ -880,7 +879,6 @@
880879
EulerAncestralDiscreteScheduler,
881880
EulerDiscreteScheduler,
882881
FlowMatchEulerDiscreteScheduler,
883-
FlowMatchEulerEDMCosmos2_0Scheduler,
884882
FlowMatchHeunDiscreteScheduler,
885883
FlowMatchLCMScheduler,
886884
HeunDiscreteScheduler,

src/diffusers/pipelines/cosmos/pipeline_cosmos_text2image.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
24-
from ...schedulers import FlowMatchEulerEDMCosmos2_0Scheduler
24+
from ...schedulers import EDMEulerScheduler
2525
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
2626
from ...utils.torch_utils import randn_tensor
2727
from ...video_processor import VideoProcessor
@@ -56,15 +56,17 @@ def __init__(self, *args, **kwargs):
5656
>>> import torch
5757
>>> from diffusers import CosmosTextToImagePipeline
5858
59-
>>> # TODO(aryan): update model_id
60-
>>> model_id = "/raid/aryan/diffusers-cosmos2-t2i-2B"
59+
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image
60+
>>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
6161
>>> pipe = CosmosTextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
6262
>>> pipe.to("cuda")
6363
6464
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
6565
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
6666
67-
>>> output = pipe(prompt=prompt, height=1024, width=1024, generator=torch.Generator().manual_seed(1)).images[0]
67+
>>> output = pipe(
68+
... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
69+
... ).images[0]
6870
>>> output.save("output.png")
6971
```
7072
"""
@@ -147,7 +149,7 @@ class CosmosTextToImagePipeline(DiffusionPipeline):
147149
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
148150
transformer ([`CosmosTransformer3DModel`]):
149151
Conditional Transformer to denoise the encoded image latents.
150-
scheduler ([`FlowMatchEulerEDMCosmos2_0Scheduler`]):
152+
scheduler ([`EDMEulerScheduler`]):
151153
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
152154
vae ([`AutoencoderKLWan`]):
153155
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
@@ -164,7 +166,7 @@ def __init__(
164166
tokenizer: T5TokenizerFast,
165167
transformer: CosmosTransformer3DModel,
166168
vae: AutoencoderKLWan,
167-
scheduler: FlowMatchEulerEDMCosmos2_0Scheduler,
169+
scheduler: EDMEulerScheduler,
168170
safety_checker: CosmosSafetyChecker = None,
169171
):
170172
super().__init__()
@@ -228,13 +230,13 @@ def _get_t5_prompt_embeds(
228230

229231
return prompt_embeds
230232

231-
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
233+
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt
232234
def encode_prompt(
233235
self,
234236
prompt: Union[str, List[str]],
235237
negative_prompt: Optional[Union[str, List[str]]] = None,
236238
do_classifier_free_guidance: bool = True,
237-
num_videos_per_prompt: int = 1,
239+
num_images_per_prompt: int = 1,
238240
prompt_embeds: Optional[torch.Tensor] = None,
239241
negative_prompt_embeds: Optional[torch.Tensor] = None,
240242
max_sequence_length: int = 512,
@@ -253,7 +255,7 @@ def encode_prompt(
253255
less than `1`).
254256
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
255257
Whether to use classifier free guidance or not.
256-
num_videos_per_prompt (`int`, *optional*, defaults to 1):
258+
num_images_per_prompt (`int`, *optional*, defaults to 1):
257259
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
258260
prompt_embeds (`torch.Tensor`, *optional*):
259261
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
@@ -282,8 +284,8 @@ def encode_prompt(
282284

283285
# duplicate text embeddings for each generation per prompt, using mps friendly method
284286
_, seq_len, _ = prompt_embeds.shape
285-
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
286-
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
287+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
288+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
287289

288290
if do_classifier_free_guidance and negative_prompt_embeds is None:
289291
negative_prompt = negative_prompt or ""
@@ -307,8 +309,8 @@ def encode_prompt(
307309

308310
# duplicate text embeddings for each generation per prompt, using mps friendly method
309311
_, seq_len, _ = negative_prompt_embeds.shape
310-
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
311-
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
312+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
313+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
312314

313315
return prompt_embeds, negative_prompt_embeds
314316

@@ -402,7 +404,7 @@ def __call__(
402404
width: int = 1280,
403405
num_inference_steps: int = 35,
404406
guidance_scale: float = 7.0,
405-
num_videos_per_prompt: Optional[int] = 1,
407+
num_images_per_prompt: Optional[int] = 1,
406408
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
407409
latents: Optional[torch.Tensor] = None,
408410
prompt_embeds: Optional[torch.Tensor] = None,
@@ -434,7 +436,7 @@ def __call__(
434436
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
435437
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
436438
`guidance_scale > 1`.
437-
num_videos_per_prompt (`int`, *optional*, defaults to 1):
439+
num_images_per_prompt (`int`, *optional*, defaults to 1):
438440
The number of images to generate per prompt.
439441
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
440442
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
@@ -521,7 +523,7 @@ def __call__(
521523
prompt=prompt,
522524
negative_prompt=negative_prompt,
523525
do_classifier_free_guidance=self.do_classifier_free_guidance,
524-
num_videos_per_prompt=num_videos_per_prompt,
526+
num_images_per_prompt=num_images_per_prompt,
525527
prompt_embeds=prompt_embeds,
526528
negative_prompt_embeds=negative_prompt_embeds,
527529
device=device,
@@ -535,7 +537,7 @@ def __call__(
535537
transformer_dtype = self.transformer.dtype
536538
num_channels_latents = self.transformer.config.in_channels
537539
latents = self.prepare_latents(
538-
batch_size * num_videos_per_prompt,
540+
batch_size * num_images_per_prompt,
539541
num_channels_latents,
540542
height,
541543
width,
@@ -612,7 +614,7 @@ def __call__(
612614
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
613615
latents.device, latents.dtype
614616
)
615-
latents = latents / latents_std + latents_mean
617+
latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean
616618
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
617619

618620
if self.safety_checker is not None:

src/diffusers/schedulers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
_import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
6060
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
6161
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
62-
_import_structure["scheduling_flow_match_euler_edm_cosmos2"] = ["FlowMatchEulerEDMCosmos2_0Scheduler"]
6362
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
6463
_import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
6564
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
@@ -162,7 +161,6 @@
162161
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
163162
from .scheduling_euler_discrete import EulerDiscreteScheduler
164163
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
165-
from .scheduling_flow_match_euler_edm_cosmos2 import FlowMatchEulerEDMCosmos2_0Scheduler
166164
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
167165
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
168166
from .scheduling_heun_discrete import HeunDiscreteScheduler

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
prediction_type: str = "epsilon",
9797
rho: float = 7.0,
9898
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
99+
use_flow_sigmas: bool = False,
99100
):
100101
if sigma_schedule not in ["karras", "exponential"]:
101102
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -169,24 +170,18 @@ def precondition_noise(self, sigma):
169170
if not isinstance(sigma, torch.Tensor):
170171
sigma = torch.tensor([sigma])
171172

172-
c_noise = 0.25 * torch.log(sigma)
173+
if self.config.use_flow_sigmas:
174+
c_noise = sigma / (sigma + 1)
175+
else:
176+
c_noise = 0.25 * torch.log(sigma)
173177

174178
return c_noise
175179

176180
def precondition_outputs(self, sample, model_output, sigma):
177-
sigma_data = self.config.sigma_data
178-
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
179-
180-
if self.config.prediction_type == "epsilon":
181-
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
182-
elif self.config.prediction_type == "v_prediction":
183-
c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
181+
if self.config.use_flow_sigmas:
182+
return self._precondition_outputs_flow(sample, model_output, sigma)
184183
else:
185-
raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
186-
187-
denoised = c_skip * sample + c_out * model_output
188-
189-
return denoised
184+
return self._precondition_outputs_edm(sample, model_output, sigma)
190185

191186
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
192187
"""
@@ -441,8 +436,40 @@ def add_noise(
441436
return noisy_samples
442437

443438
def _get_conditioning_c_in(self, sigma):
444-
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
439+
if self.config.use_flow_sigmas:
440+
t = sigma / (sigma + 1)
441+
c_in = 1.0 - t
442+
else:
443+
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
445444
return c_in
446445

446+
def _precondition_outputs_flow(self, sample, model_output, sigma):
447+
t = sigma / (sigma + 1)
448+
c_skip = 1.0 - t
449+
450+
if self.config.prediction_type == "epsilon":
451+
c_out = -t
452+
elif self.config.prediction_type == "v_prediction":
453+
c_out = t
454+
else:
455+
raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
456+
457+
denoised = c_skip * sample + c_out * model_output
458+
return denoised
459+
460+
def _precondition_outputs_edm(self, sample, model_output, sigma):
461+
sigma_data = self.config.sigma_data
462+
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
463+
464+
if self.config.prediction_type == "epsilon":
465+
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
466+
elif self.config.prediction_type == "v_prediction":
467+
c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
468+
else:
469+
raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
470+
471+
denoised = c_skip * sample + c_out * model_output
472+
return denoised
473+
447474
def __len__(self):
448475
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)