11import argparse
22from contextlib import nullcontext
33
4- import safetensors .torch
54import torch
5+ import safetensors .torch
66from accelerate import init_empty_weights
77from huggingface_hub import hf_hub_download
88
9- from diffusers .models import ZImageTransformer2DModel
109from diffusers .models .controlnets .controlnet_z_image import ZImageControlNetModel
1110from diffusers .utils .import_utils import is_accelerate_available
1211
1312
1413"""
1514python scripts/convert_z_image_controlnet_to_diffusers.py \
16- --original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \
1715 --original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \
1816 --filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
1917--output_path "z-image-controlnet-hf/"
2321CTX = init_empty_weights if is_accelerate_available else nullcontext
2422
2523parser = argparse .ArgumentParser ()
26- parser .add_argument ("--original_z_image_repo_id" , default = "Tongyi-MAI/Z-Image-Turbo" , type = str )
2724parser .add_argument ("--original_controlnet_repo_id" , default = None , type = str )
2825parser .add_argument ("--filename" , default = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" , type = str )
2926parser .add_argument ("--checkpoint_path" , default = None , type = str )
@@ -44,72 +41,29 @@ def load_original_checkpoint(args):
4441 return original_state_dict
4542
4643
47- def load_z_image (args ):
48- model = ZImageTransformer2DModel .from_pretrained (
49- args .original_z_image_repo_id , subfolder = "transformer" , torch_dtype = torch .bfloat16
50- )
51- return model .state_dict (), model .config
52-
53-
54- def convert_z_image_controlnet_checkpoint_to_diffusers (z_image , original_state_dict ):
44+ def convert_z_image_controlnet_checkpoint_to_diffusers (original_state_dict ):
5545 converted_state_dict = {}
5646
5747 converted_state_dict .update (original_state_dict )
5848
59- to_copy = {
60- "all_x_embedder." ,
61- "noise_refiner." ,
62- "context_refiner." ,
63- "t_embedder." ,
64- "cap_embedder." ,
65- "x_pad_token" ,
66- "cap_pad_token" ,
67- }
68-
69- for key in z_image .keys ():
70- for copy_key in to_copy :
71- if key .startswith (copy_key ):
72- converted_state_dict [key ] = z_image [key ]
73-
7449 return converted_state_dict
7550
7651
7752def main (args ):
7853 original_ckpt = load_original_checkpoint (args )
79- z_image , config = load_z_image (args )
8054
8155 control_in_dim = 16
8256 control_layers_places = [0 , 5 , 10 , 15 , 20 , 25 ]
8357
84- converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers (z_image , original_ckpt )
85-
86- for key , tensor in converted_controlnet_state_dict .items ():
87- print (f"{ key } - { tensor .dtype } " )
58+ converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers (original_ckpt )
8859
8960 controlnet = ZImageControlNetModel (
90- all_patch_size = config ["all_patch_size" ],
91- all_f_patch_size = config ["all_f_patch_size" ],
92- in_channels = config ["in_channels" ],
93- dim = config ["dim" ],
94- n_layers = config ["n_layers" ],
95- n_refiner_layers = config ["n_refiner_layers" ],
96- n_heads = config ["n_heads" ],
97- n_kv_heads = config ["n_kv_heads" ],
98- norm_eps = config ["norm_eps" ],
99- qk_norm = config ["qk_norm" ],
100- cap_feat_dim = config ["cap_feat_dim" ],
101- rope_theta = config ["rope_theta" ],
102- t_scale = config ["t_scale" ],
103- axes_dims = config ["axes_dims" ],
104- axes_lens = config ["axes_lens" ],
10561 control_layers_places = control_layers_places ,
10662 control_in_dim = control_in_dim ,
107- )
108- missing , unexpected = controlnet .load_state_dict (converted_controlnet_state_dict )
109- print (f"{ missing = } " )
110- print (f"{ unexpected = } " )
63+ ).to (torch .bfloat16 )
64+ controlnet .load_state_dict (converted_controlnet_state_dict )
11165 print ("Saving Z-Image ControlNet in Diffusers format" )
112- controlnet .save_pretrained (args .output_path , max_shard_size = "5GB" )
66+ controlnet .save_pretrained (args .output_path )
11367
11468
11569if __name__ == "__main__" :
0 commit comments