Skip to content

Commit 52f996e

Browse files
committed
convert passed
1 parent 0c30839 commit 52f996e

File tree

1 file changed

+6
-52
lines changed

1 file changed

+6
-52
lines changed

scripts/convert_z_image_controlnet_to_diffusers.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import argparse
22
from contextlib import nullcontext
33

4-
import safetensors.torch
54
import torch
5+
import safetensors.torch
66
from accelerate import init_empty_weights
77
from huggingface_hub import hf_hub_download
88

9-
from diffusers.models import ZImageTransformer2DModel
109
from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel
1110
from diffusers.utils.import_utils import is_accelerate_available
1211

1312

1413
"""
1514
python scripts/convert_z_image_controlnet_to_diffusers.py \
16-
--original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \
1715
--original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \
1816
--filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
1917
--output_path "z-image-controlnet-hf/"
@@ -23,7 +21,6 @@
2321
CTX = init_empty_weights if is_accelerate_available else nullcontext
2422

2523
parser = argparse.ArgumentParser()
26-
parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str)
2724
parser.add_argument("--original_controlnet_repo_id", default=None, type=str)
2825
parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str)
2926
parser.add_argument("--checkpoint_path", default=None, type=str)
@@ -44,72 +41,29 @@ def load_original_checkpoint(args):
4441
return original_state_dict
4542

4643

47-
def load_z_image(args):
48-
model = ZImageTransformer2DModel.from_pretrained(
49-
args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
50-
)
51-
return model.state_dict(), model.config
52-
53-
54-
def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict):
44+
def convert_z_image_controlnet_checkpoint_to_diffusers(original_state_dict):
5545
converted_state_dict = {}
5646

5747
converted_state_dict.update(original_state_dict)
5848

59-
to_copy = {
60-
"all_x_embedder.",
61-
"noise_refiner.",
62-
"context_refiner.",
63-
"t_embedder.",
64-
"cap_embedder.",
65-
"x_pad_token",
66-
"cap_pad_token",
67-
}
68-
69-
for key in z_image.keys():
70-
for copy_key in to_copy:
71-
if key.startswith(copy_key):
72-
converted_state_dict[key] = z_image[key]
73-
7449
return converted_state_dict
7550

7651

7752
def main(args):
7853
original_ckpt = load_original_checkpoint(args)
79-
z_image, config = load_z_image(args)
8054

8155
control_in_dim = 16
8256
control_layers_places = [0, 5, 10, 15, 20, 25]
8357

84-
converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt)
85-
86-
for key, tensor in converted_controlnet_state_dict.items():
87-
print(f"{key} - {tensor.dtype}")
58+
converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(original_ckpt)
8859

8960
controlnet = ZImageControlNetModel(
90-
all_patch_size=config["all_patch_size"],
91-
all_f_patch_size=config["all_f_patch_size"],
92-
in_channels=config["in_channels"],
93-
dim=config["dim"],
94-
n_layers=config["n_layers"],
95-
n_refiner_layers=config["n_refiner_layers"],
96-
n_heads=config["n_heads"],
97-
n_kv_heads=config["n_kv_heads"],
98-
norm_eps=config["norm_eps"],
99-
qk_norm=config["qk_norm"],
100-
cap_feat_dim=config["cap_feat_dim"],
101-
rope_theta=config["rope_theta"],
102-
t_scale=config["t_scale"],
103-
axes_dims=config["axes_dims"],
104-
axes_lens=config["axes_lens"],
10561
control_layers_places=control_layers_places,
10662
control_in_dim=control_in_dim,
107-
)
108-
missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict)
109-
print(f"{missing=}")
110-
print(f"{unexpected=}")
63+
).to(torch.bfloat16)
64+
controlnet.load_state_dict(converted_controlnet_state_dict)
11165
print("Saving Z-Image ControlNet in Diffusers format")
112-
controlnet.save_pretrained(args.output_path, max_shard_size="5GB")
66+
controlnet.save_pretrained(args.output_path)
11367

11468

11569
if __name__ == "__main__":

0 commit comments

Comments
 (0)