Commit 7dae1c6

Merge branch 'main' into add-caching-to-skyreels-v2-pipelines
2 parents cf0ea70 + f7753b1 commit 7dae1c6

36 files changed: +5470 −896 lines

docs/source/en/api/pipelines/cosmos.md

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,12 @@ output.save("output.png")
   - all
   - __call__
 
+## Cosmos2_5_PredictBasePipeline
+
+[[autodoc]] Cosmos2_5_PredictBasePipeline
+  - all
+  - __call__
+
 ## CosmosPipelineOutput
 
 [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
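For reviewers, a minimal usage sketch of the newly documented pipeline. The model path and the call arguments (`prompt`, `num_inference_steps`) are assumptions based on common diffusers conventions, not taken from this commit; the path matches the example conversion output used in the script below.

```python
# Hedged sketch: load and invoke the new pipeline. The call signature is assumed
# to follow standard text-conditioned diffusers pipelines; check the pipeline
# docstring for the actual arguments and output class.
import torch
from diffusers import Cosmos2_5_PredictBasePipeline

pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "converted/cosmos-p2.5-base-2b",  # output path of the conversion script in this commit
    torch_dtype=torch.bfloat16,
).to("cuda")

output = pipe(prompt="A robot arm stacking colored blocks", num_inference_steps=35)
# Inspect `output` for frames/images depending on the pipeline's output class.
```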

scripts/convert_cosmos_to_diffusers.py

Lines changed: 127 additions & 8 deletions
@@ -1,11 +1,55 @@
+"""
+# Cosmos 2 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2-2B-Text2Image
+```
+
+Convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
+    --text_encoder_path google-t5/t5-11b \
+    --tokenizer_path google-t5/t5-11b \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2-t2i-2b \
+    --save_pipeline
+```
+
+# Cosmos 2.5 Predict
+
+Download checkpoint
+```bash
+hf download nvidia/Cosmos-Predict2.5-2B
+```
+
+Convert checkpoint
+```bash
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/cosmos-p2.5-base-2b \
+    --save_pipeline
+```
+"""
+
 import argparse
 import pathlib
+import sys
 from typing import Any, Dict
 
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import snapshot_download
-from transformers import T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
 
 from diffusers import (
     AutoencoderKLCosmos,
@@ -17,7 +61,9 @@
     CosmosVideoToWorldPipeline,
     EDMEulerScheduler,
     FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
 )
+from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +279,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "concat_padding_mask": True,
         "extra_pos_embed_type": None,
     },
+    "Cosmos-2.5-Predict-Base-2B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }
 
 VAE_KEYS_RENAME_DICT = {
@@ -334,6 +399,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
     elif "Cosmos-2.0" in transformer_type:
         TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
         TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    elif "Cosmos-2.5" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
     else:
         assert False
 
@@ -347,6 +415,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
         new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        print(key, "->", new_key, flush=True)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
@@ -355,6 +424,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
             continue
         handler_fn_inplace(key, original_state_dict)
 
+    expected_keys = set(transformer.state_dict().keys())
+    mapped_keys = set(original_state_dict.keys())
+    missing_keys = expected_keys - mapped_keys
+    unexpected_keys = mapped_keys - expected_keys
+    if missing_keys:
+        print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in missing_keys:
+            print(k)
+        sys.exit(1)
+    if unexpected_keys:
+        print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
+        for k in unexpected_keys:
+            print(k)
+        sys.exit(2)
+
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
     return transformer
 
@@ -444,17 +528,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
     pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
 
+def save_pipeline_cosmos2_5(args, transformer, vae):
+    text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
+    tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
+
+    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        text_encoder_path, torch_dtype="auto", device_map="cpu"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+    scheduler = UniPCMultistepScheduler(
+        use_karras_sigmas=True,
+        use_flow_sigmas=True,
+        prediction_type="flow_prediction",
+        sigma_max=200.0,
+        sigma_min=0.01,
+    )
+
+    pipe = Cosmos2_5_PredictBasePipeline(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
     parser.add_argument(
-        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+        "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
     )
-    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
-    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+    parser.add_argument("--text_encoder_path", type=str, default=None)
+    parser.add_argument("--tokenizer_path", type=str, default=None)
     parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +589,6 @@ def get_args():
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None
         assert args.vae_type is not None
-        assert args.text_encoder_path is not None
-        assert args.tokenizer_path is not None
 
     if args.transformer_ckpt_path is not None:
         weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +600,26 @@ def get_args():
     if args.vae_type is not None:
         if "Cosmos-1.0" in args.transformer_type:
             vae = convert_vae(args.vae_type)
-        else:
+        elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
             vae = AutoencoderKLWan.from_pretrained(
                 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
             )
+        else:
+            raise AssertionError(f"{args.transformer_type} not supported")
+
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
         if "Cosmos-1.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
             save_pipeline_cosmos_1_0(args, transformer, vae)
         elif "Cosmos-2.0" in args.transformer_type:
+            assert args.text_encoder_path is not None
+            assert args.tokenizer_path is not None
            save_pipeline_cosmos_2_0(args, transformer, vae)
+        elif "Cosmos-2.5" in args.transformer_type:
+            save_pipeline_cosmos2_5(args, transformer, vae)
         else:
-            assert False
+            raise AssertionError(f"{args.transformer_type} not supported")
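A quick post-conversion sanity check, as a sketch: the output path matches the example command in the module docstring, and the expected values come from the `Cosmos-2.5-Predict-Base-2B` config and the scheduler constructed in this diff. Depending on how the safety checker is serialized, you may need to pass `safety_checker=None` to `from_pretrained`.

```python
# Sketch: reload the saved pipeline and confirm the converted components carry
# the configuration this script wrote (expected values taken from the diff above).
import torch
from diffusers import Cosmos2_5_PredictBasePipeline

pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "converted/cosmos-p2.5-base-2b", torch_dtype=torch.bfloat16
)

assert pipe.transformer.config.num_layers == 28
assert pipe.transformer.config.attention_head_dim == 128
assert pipe.scheduler.config.prediction_type == "flow_prediction"
assert pipe.scheduler.config.sigma_max == 200.0
```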

src/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -279,6 +279,7 @@
             "WanAnimateTransformer3DModel",
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
+            "ZImageControlNetModel",
             "ZImageTransformer2DModel",
             "attention_backend",
         ]
@@ -462,6 +463,7 @@
             "CogView4ControlPipeline",
             "CogView4Pipeline",
             "ConsisIDPipeline",
+            "Cosmos2_5_PredictBasePipeline",
             "Cosmos2TextToImagePipeline",
             "Cosmos2VideoToWorldPipeline",
             "CosmosTextToWorldPipeline",
@@ -564,6 +566,7 @@
             "QwenImageEditPlusPipeline",
             "QwenImageImg2ImgPipeline",
             "QwenImageInpaintPipeline",
+            "QwenImageLayeredPipeline",
             "QwenImagePipeline",
             "ReduxImageEncoder",
             "SanaControlNetPipeline",
@@ -669,6 +672,8 @@
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",
             "WuerstchenPriorPipeline",
+            "ZImageControlNetInpaintPipeline",
+            "ZImageControlNetPipeline",
             "ZImageImg2ImgPipeline",
             "ZImagePipeline",
         ]
@@ -1016,6 +1021,7 @@
         WanAnimateTransformer3DModel,
         WanTransformer3DModel,
         WanVACETransformer3DModel,
+        ZImageControlNetModel,
         ZImageTransformer2DModel,
         attention_backend,
     )
@@ -1170,6 +1176,7 @@
         CogView4ControlPipeline,
         CogView4Pipeline,
         ConsisIDPipeline,
+        Cosmos2_5_PredictBasePipeline,
        Cosmos2TextToImagePipeline,
         Cosmos2VideoToWorldPipeline,
         CosmosTextToWorldPipeline,
@@ -1272,6 +1279,7 @@
         QwenImageEditPlusPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
+        QwenImageLayeredPipeline,
         QwenImagePipeline,
         ReduxImageEncoder,
         SanaControlNetPipeline,
@@ -1375,6 +1383,8 @@
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
+        ZImageControlNetInpaintPipeline,
+        ZImageControlNetPipeline,
         ZImageImg2ImgPipeline,
         ZImagePipeline,
     )
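The changes in this file are pure export-table additions; a tiny sketch to confirm the new public names resolve from the package root (names taken from the lists above):

```python
# Sketch: each name newly added to diffusers' public API in this diff should be
# importable from the package root (resolution goes through the lazy-import table).
import diffusers

for name in (
    "ZImageControlNetModel",
    "Cosmos2_5_PredictBasePipeline",
    "QwenImageLayeredPipeline",
    "ZImageControlNetPipeline",
    "ZImageControlNetInpaintPipeline",
):
    assert hasattr(diffusers, name), f"{name} is missing from the public API"
```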

src/diffusers/loaders/single_file_model.py

Lines changed: 9 additions & 1 deletion
@@ -49,6 +49,7 @@
     convert_stable_cascade_unet_single_file_to_diffusers,
     convert_wan_transformer_to_diffusers,
     convert_wan_vae_to_diffusers,
+    convert_z_image_controlnet_checkpoint_to_diffusers,
     convert_z_image_transformer_checkpoint_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
@@ -172,11 +173,18 @@
         "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ZImageControlNetModel": {
+        "checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
+    },
 }
 
 
 def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
-    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+    model_state_dict_keys = set(model_state_dict.keys())
+    checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
+    is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
+    is_match = model_state_dict_keys == checkpoint_state_dict_keys
+    return not (is_subset and is_match)
 
 
 def _get_single_file_loadable_mapping_class(cls):
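The `_should_convert_state_dict_to_diffusers` change is subtle: previously a checkpoint was treated as already in diffusers format whenever the model keys were a subset of the checkpoint keys; now any mismatch, including extra checkpoint-only keys, triggers the mapping function. A toy illustration of the new behavior (the key names below are hypothetical):

```python
# Toy illustration of the updated heuristic; key names are made up.
from diffusers.loaders.single_file_model import _should_convert_state_dict_to_diffusers

model_sd = {"proj.weight": 0, "proj.bias": 0}

# Exact key match: the checkpoint is already in diffusers layout, no conversion.
assert not _should_convert_state_dict_to_diffusers(model_sd, dict(model_sd))

# Checkpoint carries extra original-format keys (e.g. a bundled refiner block).
# The old subset-only check returned False here; the new check requests conversion.
assert _should_convert_state_dict_to_diffusers(model_sd, {**model_sd, "control_noise_refiner.weight": 0})
```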

src/diffusers/loaders/single_file_utils.py

Lines changed: 24 additions & 0 deletions
@@ -121,6 +121,8 @@
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
     "z-image-turbo": "cap_embedder.0.weight",
+    "z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
+    "z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
     "sana": [
         "blocks.0.cross_attn.q_linear.weight",
         "blocks.0.cross_attn.q_linear.bias",
@@ -220,6 +222,8 @@
     "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
     "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
     "z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
+    "z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
+    "z-image-turbo-controlnet-2.x": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -779,6 +783,12 @@ def infer_diffusers_model_type(checkpoint):
         else:
             raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
 
+    elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
+        model_type = "z-image-turbo-controlnet-2.x"
+
+    elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
+        model_type = "z-image-turbo-controlnet"
+
     else:
         model_type = "v1"
 
@@ -3885,3 +3895,17 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str)
         handler_fn_inplace(key, converted_state_dict)
 
     return converted_state_dict
+
+
+def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
+    if config["add_control_noise_refiner"] is None:
+        return checkpoint
+    elif config["add_control_noise_refiner"] == "control_noise_refiner":
+        return checkpoint
+    elif config["add_control_noise_refiner"] == "control_layers":
+        converted_state_dict = {
+            key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
+        }
+        return converted_state_dict
+    else:
+        raise ValueError("Unknown Z-Image Turbo ControlNet type.")

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@
     _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
     _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
     _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
     _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
@@ -181,6 +182,7 @@
         SD3MultiControlNetModel,
         SparseControlNetModel,
         UNetControlNetXSModel,
+        ZImageControlNetModel,
     )
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
