Skip to content

Commit 2354fda

Browse files
committed
init
1 parent c318686 commit 2354fda

File tree

5 files changed

+1315
-2
lines changed

5 files changed

+1315
-2
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import argparse
2+
from contextlib import nullcontext
3+
4+
import torch
5+
import safetensors.torch
6+
from accelerate import init_empty_weights
7+
from huggingface_hub import hf_hub_download
8+
9+
from diffusers.utils.import_utils import is_accelerate_available
10+
from diffusers.models import ZImageTransformer2DModel
11+
from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel
12+
13+
"""
14+
python scripts/convert_z_image_controlnet_to_diffusers.py \
15+
--original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \
16+
--original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \
17+
--filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" \
18+
--output_path "z-image-controlnet-hf/"
19+
"""
20+
21+
22+
# Context manager factory for model construction: `init_empty_weights` when
# accelerate is installed, otherwise a no-op `nullcontext`.
# `is_accelerate_available` is a function and has to be *called*; the bare
# function object is always truthy, which made the fallback unreachable.
# NOTE(review): `accelerate` is still imported unconditionally at the top of
# this script, so the fallback only matters once that import is guarded.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

# Command-line interface. The ControlNet weights come either from the Hub
# (`--original_controlnet_repo_id` + `--filename`) or from a local
# `--checkpoint_path`; the Hub repo takes precedence when both are given.
parser = argparse.ArgumentParser()
parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str)
parser.add_argument("--original_controlnet_repo_id", default=None, type=str)
parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)

# Parsed at import time; the `__main__` guard at the bottom passes this to `main`.
args = parser.parse_args()
32+
33+
34+
def load_original_checkpoint(args):
    """Load the original ControlNet safetensors checkpoint into a state dict.

    The checkpoint is fetched from the Hub when
    ``args.original_controlnet_repo_id`` is set (using ``args.filename``);
    otherwise ``args.checkpoint_path`` is read from local disk.

    Args:
        args: Parsed CLI namespace.

    Returns:
        The mapping returned by ``safetensors.torch.load_file``.

    Raises:
        ValueError: If neither a Hub repo id nor a local path was provided.
    """
    repo_id = args.original_controlnet_repo_id
    if repo_id is not None:
        # Hub source wins when both sources are supplied.
        downloaded = hf_hub_download(repo_id=repo_id, filename=args.filename)
        return safetensors.torch.load_file(downloaded)

    local_path = args.checkpoint_path
    if local_path is None:
        raise ValueError(" please provide either `original_controlnet_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(local_path)
44+
45+
def load_z_image(args):
    """Load the base Z-Image transformer and return its weights and config.

    Args:
        args: Parsed CLI namespace; ``original_z_image_repo_id`` names the
            repo whose ``transformer`` subfolder is loaded.

    Returns:
        A ``(state_dict, config)`` tuple taken from the transformer, which is
        loaded with ``torch.bfloat16`` as its dtype.
    """
    transformer = ZImageTransformer2DModel.from_pretrained(
        args.original_z_image_repo_id,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    weights = transformer.state_dict()
    return weights, transformer.config
48+
49+
def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict):
    """Merge the ControlNet checkpoint with the shared Z-Image transformer weights.

    The result starts as a shallow copy of ``original_state_dict``. Every entry
    of ``z_image`` whose key starts with one of the shared prefixes is then
    copied in, overwriting a same-named ControlNet entry if one exists.

    Args:
        z_image: State dict of the base Z-Image transformer.
        original_state_dict: State dict of the original ControlNet checkpoint.

    Returns:
        A new dict; neither input mapping is mutated. Values are shared with
        the inputs, not cloned.
    """
    # `str.startswith` accepts a tuple, so one call replaces the inner loop.
    shared_prefixes = (
        "all_x_embedder.",
        "noise_refiner.",
        "context_refiner.",
        "t_embedder.",
        "cap_embedder.",
        "x_pad_token",
        "cap_pad_token",
    )

    converted_state_dict = dict(original_state_dict)
    converted_state_dict.update(
        {key: value for key, value in z_image.items() if key.startswith(shared_prefixes)}
    )
    return converted_state_dict
62+
63+
64+
def main(args):
    """Convert a Z-Image ControlNet checkpoint and save it in Diffusers format.

    Loads the original ControlNet weights and the base Z-Image transformer,
    merges them into one state dict, builds a ``ZImageControlNetModel`` from
    the transformer's config, loads the merged weights into it, and saves the
    model to ``args.output_path``.

    Args:
        args: Parsed CLI namespace. ``output_path`` is read here; the loader
            helpers read the repo id, filename and checkpoint path fields.

    Raises:
        ValueError: If ``args.output_path`` is missing or empty, or (from the
            checkpoint loader) if no ControlNet source was given.
    """
    # Fail fast: `output_path` has no CLI default and is otherwise only touched
    # by the final `save_pretrained` call, i.e. after all loading is done.
    if not args.output_path:
        raise ValueError("please provide an `output_path` to save the converted ControlNet to")

    original_ckpt = load_original_checkpoint(args)
    z_image, config = load_z_image(args)

    # ControlNet-specific settings that are not part of the transformer config.
    # NOTE(review): hard-coded here; presumably they match the default
    # Fun-Controlnet-Union checkpoint. Confirm before converting another one.
    control_in_dim = 16
    control_layers_places = [0, 5, 10, 15, 20, 25]

    converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt)

    # Print each merged tensor's dtype.
    for key, tensor in converted_controlnet_state_dict.items():
        print(f"{key} - {tensor.dtype}")

    # Architecture hyperparameters are taken from the base transformer's config.
    controlnet = ZImageControlNetModel(
        all_patch_size=config["all_patch_size"],
        all_f_patch_size=config["all_f_patch_size"],
        in_channels=config["in_channels"],
        dim=config["dim"],
        n_layers=config["n_layers"],
        n_refiner_layers=config["n_refiner_layers"],
        n_heads=config["n_heads"],
        n_kv_heads=config["n_kv_heads"],
        norm_eps=config["norm_eps"],
        qk_norm=config["qk_norm"],
        cap_feat_dim=config["cap_feat_dim"],
        rope_theta=config["rope_theta"],
        t_scale=config["t_scale"],
        axes_dims=config["axes_dims"],
        axes_lens=config["axes_lens"],
        control_layers_places=control_layers_places,
        control_in_dim=control_in_dim,
    )
    missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict)
    print(f"{missing=}")
    print(f"{unexpected=}")
    print("Saving Z-Image ControlNet in Diffusers format")
    controlnet.save_pretrained(args.output_path, max_shard_size="5GB")
100+
101+
102+
# Script entry point: run the conversion with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from .controlnet_union import ControlNetUnionModel
2121
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
22+
from .controlnet_z_image import ZImageControlNetModel
2223
from .multicontrolnet import MultiControlNetModel
2324
from .multicontrolnet_union import MultiControlNetUnionModel
2425

0 commit comments

Comments
 (0)