Skip to content

Commit dd9775c

Browse files
committed
pop control_noise_refiner from 2.0 state_dict
1 parent f4b7fcc commit dd9775c

File tree

3 files changed

+21
-17
lines changed

3 files changed

+21
-17
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
convert_stable_cascade_unet_single_file_to_diffusers,
5050
convert_wan_transformer_to_diffusers,
5151
convert_wan_vae_to_diffusers,
52+
convert_z_image_controlnet_checkpoint_to_diffusers,
5253
convert_z_image_transformer_checkpoint_to_diffusers,
5354
create_controlnet_diffusers_config_from_ldm,
5455
create_unet_diffusers_config_from_ldm,
@@ -174,14 +175,18 @@
174175
"default_subfolder": "transformer",
175176
},
176177
"ZImageControlNetModel": {
177-
"checkpoint_mapping_fn": lambda x: x,
178+
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
178179
"config_create_fn": create_z_image_controlnet_config,
179180
},
180181
}
181182

182183

183184
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
184-
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
185+
model_state_dict_keys = set(model_state_dict.keys())
186+
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
187+
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
188+
is_match = model_state_dict_keys == checkpoint_state_dict_keys
189+
return not (is_subset and is_match)
185190

186191

187192
def _get_single_file_loadable_mapping_class(cls):

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3925,3 +3925,16 @@ def create_z_image_controlnet_config(checkpoint, **kwargs):
39253925
return v2_config
39263926
else:
39273927
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
3928+
3929+
3930+
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, **kwargs):
3931+
control_x_embedder_weight_shape = checkpoint["control_all_x_embedder.2-1.weight"].shape[1]
3932+
if control_x_embedder_weight_shape == 64:
3933+
return checkpoint
3934+
elif control_x_embedder_weight_shape == 132:
3935+
converted_state_dict = {
3936+
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
3937+
}
3938+
return converted_state_dict
3939+
else:
3940+
raise ValueError("Unknown Z-Image Turbo ControlNet type.")

src/diffusers/models/controlnets/controlnet_z_image.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -432,21 +432,7 @@ def __init__(
432432

433433
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
434434
if self.add_control_noise_refiner:
435-
self.control_noise_refiner = nn.ModuleList(
436-
[
437-
ZImageControlTransformerBlock(
438-
1000 + layer_id,
439-
dim,
440-
n_heads,
441-
n_kv_heads,
442-
norm_eps,
443-
qk_norm,
444-
modulation=True,
445-
block_id=layer_id,
446-
)
447-
for layer_id in range(n_refiner_layers)
448-
]
449-
)
435+
self.control_noise_refiner = None
450436
else:
451437
self.control_noise_refiner = nn.ModuleList(
452438
[

0 commit comments

Comments
 (0)