
Commit 0c30839

ruff

1 parent 1e2009d

File tree

4 files changed: +36 -26 lines


scripts/convert_z_image_controlnet_to_diffusers.py

Lines changed: 17 additions & 4 deletions
@@ -1,14 +1,15 @@
 import argparse
 from contextlib import nullcontext
 
-import torch
 import safetensors.torch
+import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
 
-from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.models import ZImageTransformer2DModel
 from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel
+from diffusers.utils.import_utils import is_accelerate_available
+
 
 """
 python scripts/convert_z_image_controlnet_to_diffusers.py \
@@ -42,16 +43,28 @@ def load_original_checkpoint(args):
     original_state_dict = safetensors.torch.load_file(ckpt_path)
     return original_state_dict
 
+
 def load_z_image(args):
-    model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+    model = ZImageTransformer2DModel.from_pretrained(
+        args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
+    )
     return model.state_dict(), model.config
 
+
 def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict):
     converted_state_dict = {}
 
     converted_state_dict.update(original_state_dict)
 
-    to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"}
+    to_copy = {
+        "all_x_embedder.",
+        "noise_refiner.",
+        "context_refiner.",
+        "t_embedder.",
+        "cap_embedder.",
+        "x_pad_token",
+        "cap_pad_token",
+    }
 
     for key in z_image.keys():
         for copy_key in to_copy:
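
The second hunk is truncated at the prefix-matching loop. For context, a minimal sketch of how such prefix-based copying typically continues — the loop body below is an assumption for illustration, not part of this commit:

# Hypothetical continuation of the truncated loop: copy every weight whose
# name starts with one of the `to_copy` prefixes from the base transformer
# state dict into the converted ControlNet state dict.
for key in z_image.keys():
    for copy_key in to_copy:
        if key.startswith(copy_key):
            converted_state_dict[key] = z_image[key]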

src/diffusers/models/controlnets/controlnet_z_image.py

Lines changed: 17 additions & 19 deletions
@@ -20,23 +20,26 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.normalization import RMSNorm
 from ..controlnets.controlnet import zero_module
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
+from ..transformers.transformer_z_image import (
+    SEQ_MULTI_OF,
+    ZImageTransformer2DModel,
+    ZImageTransformerBlock,
+)
 
 
 class ZImageControlTransformerBlock(ZImageTransformerBlock):
     def __init__(
-        self,
+        self,
         layer_id: int,
         dim: int,
         n_heads: int,
         n_kv_heads: int,
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
-        block_id=0
+        block_id=0,
     ):
         super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
         self.block_id = block_id
@@ -57,7 +60,8 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs):
         all_c += [c_skip, c]
         c = torch.stack(all_c)
         return c
-
+
+
 class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
 
@@ -72,7 +76,7 @@ def __init__(
         n_kv_heads=30,
         norm_eps=1e-5,
         qk_norm=True,
-        control_layers_places: List[int]=None,
+        control_layers_places: List[int] = None,
         control_in_dim=None,
     ):
         super().__init__()
@@ -84,15 +88,7 @@ def __init__(
         # control blocks
         self.control_layers = nn.ModuleList(
             [
-                ZImageControlTransformerBlock(
-                    i,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    block_id=i
-                )
+                ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
                 for i in self.control_layers_places
             ]
         )
@@ -425,7 +421,9 @@ def forward(
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.control_noise_refiner:
-                control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input)
+                control_context = self._gradient_checkpointing_func(
+                    layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
+                )
         else:
             for layer in self.control_noise_refiner:
                 control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
@@ -440,14 +438,14 @@ def forward(
         control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
         c = control_context_unified
 
-        new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input)
-
+        new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
+
         for layer in self.control_layers:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 c = self._gradient_checkpointing_func(layer, c, **new_kwargs)
             else:
                 c = layer(c, **new_kwargs)
-
+
         hints = torch.unbind(c)[:-1] * conditioning_scale
         controlnet_block_samples = {}
         for layer_idx in range(self.n_layers):
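
Besides line wrapping, the one substantive-looking edit above swaps a dict() call for an equivalent dict literal (ruff's C408 fix, from flake8-comprehensions). A standalone illustration, not code from this commit:

# dict(...) with keyword arguments and a dict literal build identical
# mappings; ruff's autofix prefers the literal form.
kwargs_call = dict(x=1, attn_mask=None, adaln_input=None)
kwargs_literal = {"x": 1, "attn_mask": None, "adaln_input": None}
assert kwargs_call == kwargs_literal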

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 1 addition & 1 deletion
@@ -538,7 +538,7 @@ def forward(
         cap_feats: List[torch.Tensor],
         patch_size=2,
         f_patch_size=1,
-        controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None,
+        controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None,
         return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size

src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py

Lines changed: 1 addition & 2 deletions
@@ -89,7 +89,6 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -509,7 +508,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.vae.dtype,
-                )
+            )
             height, width = control_image.shape[-2:]
             control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
             control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
