Skip to content

Commit a63e543

Browse files
committed
add pipeline
1 parent 88faab1 commit a63e543

File tree

6 files changed

+877
-0
lines changed

6 files changed

+877
-0
lines changed
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import argparse
2+
from typing import Any, Dict
3+
4+
import torch
5+
from accelerate import init_empty_weights
6+
7+
from diffusers import CosmosTransformer3DModel
8+
9+
10+
def remove_keys_(key: str, state_dict: Dict[str, Any]):
    """Delete ``key`` from ``state_dict`` in place.

    Takes ``(key, state_dict)`` so it can be used as a special-key handler.

    Raises:
        KeyError: if ``key`` is not present in ``state_dict``.
    """
    del state_dict[key]
12+
13+
14+
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename ``old_key`` to ``new_key`` in ``state_dict`` in place.

    The value is popped before it is re-inserted, so passing the same name for
    both keys is safe (the entry is just moved to the end of the dict).

    Args:
        state_dict: Mapping to mutate.
        old_key: Existing key whose value should be moved.
        new_key: Key under which the value is stored afterwards.

    Returns:
        None. The mapping is modified in place.

    Raises:
        KeyError: if ``old_key`` is not present in ``state_dict``.
    """
    state_dict[new_key] = state_dict.pop(old_key)
16+
17+
18+
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    """Move a ``blocks.block<N>...`` entry to ``transformer_blocks.<N>...`` in place.

    The block index is parsed from the second dot-separated segment of ``key``
    (for example ``block7`` gives ``7``). Everything after the
    ``blocks.block<N>`` prefix is kept unchanged.
    """
    second_segment = key.split(".")[1]
    idx = int(second_segment.removeprefix("block"))

    # Whatever follows the original per-block prefix is carried over verbatim.
    remainder = key.removeprefix(f"blocks.block{idx}")

    state_dict[f"transformer_blocks.{idx}{remainder}"] = state_dict.pop(key)
27+
28+
29+
# Substring replacements applied to every transformer checkpoint key by
# convert_transformer. They are applied one after another in the insertion
# order below, and each str.replace operates on the result of the previous
# one, so the order of the entries matters.
TRANSFORMER_KEYS_RENAME_DICT = {
30+
"t_embedder.1": "time_embed.t_embedder",
31+
"affline_norm": "time_embed.norm",
32+
".blocks.0.block.attn": ".attn1",
33+
".blocks.1.block.attn": ".attn2",
34+
".blocks.2.block": ".ff",
35+
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
36+
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
37+
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
38+
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
39+
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
40+
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
41+
"to_q.0": "to_q",
42+
"to_q.1": "norm_q",
43+
"to_k.0": "to_k",
44+
"to_k.1": "norm_k",
45+
"to_v.0": "to_v",
46+
"layer1": "net.0.proj",
47+
"layer2": "net.2",
48+
"proj.1": "proj",
49+
"x_embedder": "patch_embed",
50+
"extra_pos_embedder": "learnable_pos_embed",
51+
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
52+
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
53+
"final_layer.linear": "proj_out",
54+
}
55+
56+
# Keys that need more than a plain substring replacement. After the renaming
# pass, convert_transformer calls handler(key, state_dict) for every key that
# contains one of these substrings; each handler mutates the state dict in
# place (renaming the per-block prefix, or dropping the entry).
TRANSFORMER_SPECIAL_KEYS_REMAP = {
57+
"blocks.block": rename_transformer_blocks_,
58+
"logvar.0.freqs": remove_keys_,
59+
"logvar.0.phases": remove_keys_,
60+
"logvar.1.weight": remove_keys_,
61+
"pos_embedder.seq": remove_keys_,
62+
}
63+
64+
# Empty placeholders for the VAE conversion. As far as this script shows, they
# are only referenced from the commented-out convert_vae.
VAE_KEYS_RENAME_DICT = {}
65+
66+
VAE_SPECIAL_KEYS_REMAP = {}
67+
68+
69+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap a loaded checkpoint down to the actual state dict.

    Checkpoints may nest the weights under ``"model"``, ``"module"`` and/or
    ``"state_dict"``. The wrappers are peeled off in that order, and each one is
    looked up in the *current* (already unwrapped) level. Nested wrappers such
    as ``{"model": {"module": {...}}}`` are therefore fully unwrapped, and
    unrelated top-level keys that share a wrapper name no longer cause a
    ``KeyError``.

    Args:
        saved_dict: Object returned by loading the checkpoint.

    Returns:
        The innermost mapping found, or ``saved_dict`` itself when none of the
        wrapper keys is present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Only descend while we are still looking at a dict; this avoids
        # running `in` against a non-mapping value.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
78+
79+
80+
def convert_transformer(ckpt_path: str):
    """Load an original transformer checkpoint into a ``CosmosTransformer3DModel``.

    The checkpoint is loaded on CPU and unwrapped with ``get_state_dict``. Its
    keys are then rewritten in two passes:

    1. a leading ``"net."`` is stripped and the ``TRANSFORMER_KEYS_RENAME_DICT``
       substring replacements are applied;
    2. every key containing a ``TRANSFORMER_SPECIAL_KEYS_REMAP`` substring is
       handed to the matching in-place handler.

    The resulting state dict is loaded strictly into a model that was created
    under ``init_empty_weights``.

    Args:
        ckpt_path: Path to the original transformer checkpoint.

    Returns:
        The ``CosmosTransformer3DModel`` with the converted weights assigned.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    original_state_dict = get_state_dict(checkpoint)

    with init_empty_weights():
        transformer = CosmosTransformer3DModel()

    prefix = "net."

    # Pass 1: plain renames. Iterate over a snapshot because the dict is
    # mutated while we go.
    for old_name in tuple(original_state_dict):
        renamed = old_name.removeprefix(prefix)
        for pattern, replacement in TRANSFORMER_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(pattern, replacement)
        update_state_dict_(original_state_dict, old_name, renamed)

    # Pass 2: keys that need a dedicated handler (block re-indexing, removal).
    for name in tuple(original_state_dict):
        for pattern, handler in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if pattern in name:
                handler(name, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
103+
104+
105+
# def convert_vae(ckpt_path: str):
106+
# original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
107+
108+
# with init_empty_weights():
109+
# vae = AutoencoderKLHunyuanVideo()
110+
111+
# for key in list(original_state_dict.keys()):
112+
# new_key = key[:]
113+
# for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
114+
# new_key = new_key.replace(replace_key, rename_key)
115+
# update_state_dict_(original_state_dict, key, new_key)
116+
117+
# for key in list(original_state_dict.keys()):
118+
# for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
119+
# if special_key not in key:
120+
# continue
121+
# handler_fn_inplace(key, original_state_dict)
122+
123+
# vae.load_state_dict(original_state_dict, strict=True, assign=True)
124+
# return vae
125+
126+
127+
def get_args():
    """Parse the command-line arguments of the conversion script.

    ``--dtype`` is restricted to the keys of ``DTYPE_MAPPING``, so an
    unsupported value is rejected by argparse with a usage message instead of
    surfacing later as a bare ``KeyError``.

    Returns:
        argparse.Namespace with ``transformer_ckpt_path``, ``vae_ckpt_path``,
        ``text_encoder_path``, ``tokenizer_path``, ``text_encoder_2_path``,
        ``save_pipeline``, ``output_path`` and ``dtype``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument(
        "--dtype",
        default="bf16",
        # DTYPE_MAPPING is defined at module level; it is resolved when the
        # function is called, not when it is defined.
        choices=tuple(DTYPE_MAPPING),
        help="Torch dtype to save the transformer in.",
    )
    return parser.parse_args()
140+
141+
142+
# Maps the ``--dtype`` command-line value to the torch dtype the converted
# transformer is cast to before saving.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    # Fail with a readable message instead of a bare KeyError on an unknown dtype.
    if args.dtype not in DTYPE_MAPPING:
        raise ValueError(f"Unsupported --dtype {args.dtype!r}; expected one of {sorted(DTYPE_MAPPING)}.")
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        required_for_pipeline = (
            "transformer_ckpt_path",
            "vae_ckpt_path",
            "text_encoder_path",
            "tokenizer_path",
            "text_encoder_2_path",
        )
        missing = [f"--{name}" for name in required_for_pipeline if getattr(args, name) is None]
        if missing:
            raise ValueError(f"--save_pipeline requires the following arguments: {', '.join(missing)}")

    transformer = None
    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        # With --save_pipeline the transformer is not written on its own here.
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
166+
167+
# if args.vae_ckpt_path is not None:
168+
# vae = convert_vae(args.vae_ckpt_path)
169+
# if not args.save_pipeline:
170+
# vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
171+
172+
# if args.save_pipeline:
173+
# text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
174+
# tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
175+
# text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
176+
# tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
177+
# scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
178+
179+
# pipe = CosmosPipeline(
180+
# transformer=transformer,
181+
# vae=vae,
182+
# text_encoder=text_encoder,
183+
# tokenizer=tokenizer,
184+
# text_encoder_2=text_encoder_2,
185+
# tokenizer_2=tokenizer_2,
186+
# scheduler=scheduler,
187+
# )
188+
# pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@
287287
"CogVideoXVideoToVideoPipeline",
288288
"CogView3PlusPipeline",
289289
"ConsisIDPipeline",
290+
"CosmosPipeline",
290291
"CycleDiffusionPipeline",
291292
"FluxControlImg2ImgPipeline",
292293
"FluxControlInpaintPipeline",
@@ -781,6 +782,7 @@
781782
CogVideoXVideoToVideoPipeline,
782783
CogView3PlusPipeline,
783784
ConsisIDPipeline,
785+
CosmosPipeline,
784786
CycleDiffusionPipeline,
785787
FluxControlImg2ImgPipeline,
786788
FluxControlInpaintPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
]
156156
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
157157
_import_structure["consisid"] = ["ConsisIDPipeline"]
158+
_import_structure["cosmos"] = ["CosmosPipeline"]
158159
_import_structure["controlnet"].extend(
159160
[
160161
"BlipDiffusionControlNetPipeline",
@@ -518,6 +519,7 @@
518519
StableDiffusionControlNetXSPipeline,
519520
StableDiffusionXLControlNetXSPipeline,
520521
)
522+
from .cosmos import CosmosPipeline
521523
from .deepfloyd_if import (
522524
IFImg2ImgPipeline,
523525
IFImg2ImgSuperResolutionPipeline,
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
# Lazy-import setup for this pipeline subpackage.
# _dummy_objects collects the objects taken from
# dummy_torch_and_transformers_objects when a dependency is missing;
# _import_structure maps submodule names to the public names they provide.
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
17+
# Register CosmosPipeline only when both transformers and torch are available;
# otherwise fall back to the dummy objects.
try:
18+
if not (is_transformers_available() and is_torch_available()):
19+
raise OptionalDependencyNotAvailable()
20+
except OptionalDependencyNotAvailable:
21+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
22+
23+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24+
else:
25+
_import_structure["pipeline_cosmos"] = ["CosmosPipeline"]
26+
27+
# Under static type checking, or when DIFFUSERS_SLOW_IMPORT is set, import the
# real (or dummy) objects directly instead of going through the lazy module.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28+
try:
29+
if not (is_transformers_available() and is_torch_available()):
30+
raise OptionalDependencyNotAvailable()
31+
32+
except OptionalDependencyNotAvailable:
33+
from ...utils.dummy_torch_and_transformers_objects import *
34+
else:
35+
from .pipeline_cosmos import CosmosPipeline
36+
37+
# Otherwise replace this module in sys.modules with a _LazyModule built from
# _import_structure, then attach the collected dummy objects to it.
else:
38+
import sys
39+
40+
sys.modules[__name__] = _LazyModule(
41+
__name__,
42+
globals()["__file__"],
43+
_import_structure,
44+
module_spec=__spec__,
45+
)
46+
47+
for name, value in _dummy_objects.items():
48+
setattr(sys.modules[__name__], name, value)

0 commit comments

Comments
 (0)