
Commit 4d90851

support text-to-image
1 parent 73a9d58 commit 4d90851

File tree

11 files changed: +1281, -49 lines

scripts/convert_cosmos_to_diffusers.py

Lines changed: 144 additions & 31 deletions
@@ -7,7 +7,15 @@
 from huggingface_hub import snapshot_download
 from transformers import T5EncoderModel, T5TokenizerFast
 
-from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import (
+    AutoencoderKLCosmos,
+    AutoencoderKLWan,
+    CosmosTextToImagePipeline,
+    CosmosTextToWorldPipeline,
+    CosmosTransformer3DModel,
+    EDMEulerScheduler,
+    FlowMatchEulerEDMCosmos2_0Scheduler,
+)
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -29,7 +37,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-TRANSFORMER_KEYS_RENAME_DICT = {
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
     "t_embedder.1": "time_embed.t_embedder",
     "affline_norm": "time_embed.norm",
     ".blocks.0.block.attn": ".attn1",
@@ -56,14 +64,53 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
     "final_layer.linear": "proj_out",
 }
 
-TRANSFORMER_SPECIAL_KEYS_REMAP = {
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
     "blocks.block": rename_transformer_blocks_,
     "logvar.0.freqs": remove_keys_,
     "logvar.0.phases": remove_keys_,
     "logvar.1.weight": remove_keys_,
     "pos_embedder.seq": remove_keys_,
 }
 
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+    "t_embedder.1": "time_embed.t_embedder",
+    "t_embedding_norm": "time_embed.norm",
+    "blocks": "transformer_blocks",
+    "adaln_modulation_self_attn.1": "norm1.linear_1",
+    "adaln_modulation_self_attn.2": "norm1.linear_2",
+    "adaln_modulation_cross_attn.1": "norm2.linear_1",
+    "adaln_modulation_cross_attn.2": "norm2.linear_2",
+    "adaln_modulation_mlp.1": "norm3.linear_1",
+    "adaln_modulation_mlp.2": "norm3.linear_2",
+    "self_attn": "attn1",
+    "cross_attn": "attn2",
+    "q_proj": "to_q",
+    "k_proj": "to_k",
+    "v_proj": "to_v",
+    "output_proj": "to_out.0",
+    "q_norm": "norm_q",
+    "k_norm": "norm_k",
+    "mlp.layer1": "ff.net.0.proj",
+    "mlp.layer2": "ff.net.2",
+    "x_embedder.proj.1": "patch_embed.proj",
+    # "extra_pos_embedder": "learnable_pos_embed",
+    "final_layer.adaln_modulation.1": "norm_out.linear_1",
+    "final_layer.adaln_modulation.2": "norm_out.linear_2",
+    "final_layer.linear": "proj_out",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+    "accum_video_sample_counter": remove_keys_,
+    "accum_image_sample_counter": remove_keys_,
+    "accum_iteration": remove_keys_,
+    "accum_train_in_hours": remove_keys_,
+    "pos_embedder.seq": remove_keys_,
+    "pos_embedder.dim_spatial_range": remove_keys_,
+    "pos_embedder.dim_temporal_range": remove_keys_,
+    "_extra_state": remove_keys_,
+}
+
+
 TRANSFORMER_CONFIGS = {
     "Cosmos-1.0-Diffusion-7B-Text2World": {
         "in_channels": 16,
@@ -125,6 +172,21 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "concat_padding_mask": True,
         "extra_pos_embed_type": "learnable",
     },
+    "Cosmos-2.0-Diffusion-2B-Text2Image": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 1.0, 1.0),
+        "concat_padding_mask": True,
+        "extra_pos_embed_type": None,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
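Two derived numbers help sanity-check the new "Cosmos-2.0-Diffusion-2B-Text2Image" entry: the hidden width works out to num_attention_heads * attention_head_dim = 16 * 128 = 2048, and max_size together with patch_size bounds the positional grid. A quick check, assuming the usual per-axis division and a (frames, height, width) ordering:

max_size = (128, 240, 240)   # latent-space (frames, height, width)
patch_size = (1, 2, 2)       # (p_t, p_h, p_w)
grid = tuple(s // p for s, p in zip(max_size, patch_size))
print(grid)  # (128, 120, 120): the largest patch grid the positional embeddings must cover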
@@ -216,9 +278,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-def convert_transformer(transformer_type: str, ckpt_path: str):
+def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
     PREFIX_KEY = "net."
-    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
+
+    if "Cosmos-1.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    elif "Cosmos-2.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    else:
+        assert False
 
     with init_empty_weights():
         config = TRANSFORMER_CONFIGS[transformer_type]
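The dispatch above only swaps which tables are used; the remainder of convert_transformer (unchanged in this diff) consumes them in the two-pass style common to diffusers conversion scripts: plain substring renames first, then handler functions for keys that need more than a rename. A minimal sketch under those assumptions, not the verbatim function:

def apply_remaps(state_dict, rename_dict, special_keys_remap):
    # Pass 1: substring renames, applied in table order, so one key can pass
    # through several rules, e.g. (hypothetical checkpoint key):
    #   "blocks.3.self_attn.q_proj.weight" -> "transformer_blocks.3.attn1.to_q.weight"
    for key in list(state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        state_dict[new_key] = state_dict.pop(key)
    # Pass 2: handlers such as remove_keys_ cover keys that need surgery or
    # outright removal (e.g. the "_extra_state" blobs in Cosmos 2.0 checkpoints).
    for key in list(state_dict.keys()):
        for special_key, handler_fn in special_keys_remap.items():
            if special_key in key:
                handler_fn(key, state_dict)
    return state_dict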
@@ -281,13 +352,66 @@ def convert_vae(vae_type: str):
     return vae
 
 
+def save_pipeline_cosmos_1_0(args, transformer, vae, dtype):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+    # The original code initializes the EDM config with sigma_min=0.0002 but does not use it anywhere directly,
+    # so the sigma_min that actually takes effect is the default value of 0.002.
+    scheduler = EDMEulerScheduler(
+        sigma_min=0.002,
+        sigma_max=80,
+        sigma_data=0.5,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe = CosmosTextToWorldPipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def save_pipeline_cosmos_2_0(args, transformer, vae, dtype):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+
+    scheduler = FlowMatchEulerEDMCosmos2_0Scheduler(
+        sigma_min=0.0002,
+        sigma_max=80,
+        sigma_data=1.0,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe = CosmosTextToImagePipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
-    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
+    parser.add_argument(
+        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+    )
     parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
     parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
     parser.add_argument("--save_pipeline", action="store_true")
@@ -316,37 +440,26 @@ def get_args():
     assert args.tokenizer_path is not None
 
     if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
+        weights_only = "Cosmos-1.0" in args.transformer_type
+        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
         transformer = transformer.to(dtype=dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.vae_type is not None:
-        vae = convert_vae(args.vae_type)
+        if "Cosmos-1.0" in args.transformer_type:
+            vae = convert_vae(args.vae_type)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+            )
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
-        text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
-        tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
-        # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
-        # So, the sigma_min values that is used is the default value of 0.002.
-        scheduler = EDMEulerScheduler(
-            sigma_min=0.002,
-            sigma_max=80,
-            sigma_data=0.5,
-            sigma_schedule="karras",
-            num_train_timesteps=1000,
-            prediction_type="epsilon",
-            rho=7.0,
-            final_sigmas_type="sigma_min",
-        )
-
-        pipe = CosmosTextToWorldPipeline(
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            transformer=transformer,
-            vae=vae,
-            scheduler=scheduler,
-        )
-        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+        if "Cosmos-1.0" in args.transformer_type:
+            save_pipeline_cosmos_1_0(args, transformer, vae, dtype)
+        elif "Cosmos-2.0" in args.transformer_type:
+            save_pipeline_cosmos_2_0(args, transformer, vae, dtype)
+        else:
+            assert False
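End to end, converting a Cosmos 2.0 checkpoint with --save_pipeline yields a directory loadable through the new pipeline class. A hypothetical usage sketch: the local path stands in for whatever was passed as --output_path, and the call assumes the standard diffusers text-to-image interface:

import torch
from diffusers import CosmosTextToImagePipeline

pipe = CosmosTextToImagePipeline.from_pretrained(
    "./cosmos-2.0-2b-text2image-diffusers",  # hypothetical --output_path directory
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(prompt="A red vintage car on a coastal road at sunset").images[0]
image.save("cosmos_t2i.png")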

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -271,6 +271,7 @@
     "EulerAncestralDiscreteScheduler",
     "EulerDiscreteScheduler",
     "FlowMatchEulerDiscreteScheduler",
+    "FlowMatchEulerEDMCosmos2_0Scheduler",
     "FlowMatchHeunDiscreteScheduler",
     "FlowMatchLCMScheduler",
     "HeunDiscreteScheduler",
@@ -361,6 +362,7 @@
     "CogView4ControlPipeline",
     "CogView4Pipeline",
     "ConsisIDPipeline",
+    "CosmosTextToImagePipeline",
     "CosmosTextToWorldPipeline",
     "CosmosVideoToWorldPipeline",
     "CycleDiffusionPipeline",
@@ -878,6 +880,7 @@
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     FlowMatchEulerDiscreteScheduler,
+    FlowMatchEulerEDMCosmos2_0Scheduler,
     FlowMatchHeunDiscreteScheduler,
     FlowMatchLCMScheduler,
     HeunDiscreteScheduler,
@@ -949,6 +952,7 @@
     CogView4ControlPipeline,
     CogView4Pipeline,
     ConsisIDPipeline,
+    CosmosTextToImagePipeline,
     CosmosTextToWorldPipeline,
     CosmosVideoToWorldPipeline,
     CycleDiffusionPipeline,
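These four one-line additions register the new scheduler and pipeline twice each, following diffusers' lazy-import convention: the name appears once as a string in _import_structure (so `import diffusers` stays cheap) and once as a real import on the type-checking branch. Schematically, a sketch of the convention rather than the literal file:

from typing import TYPE_CHECKING

_import_structure = {
    "schedulers": ["FlowMatchEulerEDMCosmos2_0Scheduler"],
    "pipelines": ["CosmosTextToImagePipeline"],
}

if TYPE_CHECKING:
    # Static type checkers see the real symbols...
    from .pipelines import CosmosTextToImagePipeline
    from .schedulers import FlowMatchEulerEDMCosmos2_0Scheduler
else:
    # ...while at runtime a _LazyModule resolves attributes on first access.
    import sys
    from .utils import _LazyModule

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)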

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 0 additions & 2 deletions
@@ -544,8 +544,6 @@ def forward(
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
         hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
-        # Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
-        # Another few hours of sanity lost to the void.
         hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
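The two deleted comments were venting about this permutation, so it is worth spelling out what it does: after the two unflatten calls the tensor is (B, T, H, W, p_h, p_w, p_t, C), and the permute moves each patch axis next to its grid axis so the final flattens reassemble frames, height, and width. A standalone shape check with assumed toy sizes:

import torch

B, C = 1, 16
T, H, W = 2, 4, 4        # post-patch frames / height / width
p_t, p_h, p_w = 1, 2, 2  # temporal / spatial patch sizes

x = torch.randn(B, T * H * W, p_h * p_w * p_t * C)  # output of proj_out
x = x.unflatten(2, (p_h, p_w, p_t, -1))             # (B, S, p_h, p_w, p_t, C)
x = x.unflatten(1, (T, H, W))                       # (B, T, H, W, p_h, p_w, p_t, C)
x = x.permute(0, 7, 1, 6, 2, 4, 3, 5)               # (B, C, T, p_t, H, p_h, W, p_w)
x = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)     # (B, C, T*p_t, H*p_h, W*p_w)
print(x.shape)  # torch.Size([1, 16, 2, 8, 8])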

src/diffusers/pipelines/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -157,7 +157,11 @@
     _import_structure["cogview3"] = ["CogView3PlusPipeline"]
     _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
     _import_structure["consisid"] = ["ConsisIDPipeline"]
-    _import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
+    _import_structure["cosmos"] = [
+        "CosmosTextToImagePipeline",
+        "CosmosTextToWorldPipeline",
+        "CosmosVideoToWorldPipeline",
+    ]
     _import_structure["controlnet"].extend(
         [
             "BlipDiffusionControlNetPipeline",
@@ -559,7 +563,7 @@
         StableDiffusionControlNetXSPipeline,
         StableDiffusionXLControlNetXSPipeline,
     )
-    from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
+    from .cosmos import CosmosTextToImagePipeline, CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
     from .deepfloyd_if import (
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["pipeline_cosmos_text2image"] = ["CosmosTextToImagePipeline"]
     _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
     _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
 
@@ -33,6 +34,7 @@
 except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import *
 else:
+    from .pipeline_cosmos_text2image import CosmosTextToImagePipeline
     from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
     from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
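The else-branches above only run when torch and transformers are importable; otherwise the except-branch substitutes generated dummy objects that fail loudly at instantiation time. Roughly what such a placeholder looks like in diffusers, shown as an illustrative sketch:

from diffusers.utils import DummyObject, requires_backends

class CosmosTextToImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Raises an ImportError explaining which backends are missing.
        requires_backends(self, ["torch", "transformers"])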
