Skip to content

Commit 3ed0e55

Browse files
fix: resolve linting errors in Z-Image ControlNet support
- Add missing ControlNet_Checkpoint_ZImage_Config import
- Remove unused imports (Any, Dict, ADALN_EMBED_DIM, is_torch_version)
- Add strict=True to zip() calls
- Replace mutable list defaults with immutable tuples
- Replace dict() calls with literal syntax
- Sort imports in z_image_denoise.py
1 parent 8db8aa8 commit 3ed0e55

File tree

4 files changed

+22
-24
lines changed

4 files changed

+22
-24
lines changed

invokeai/app/invocations/z_image_denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
LatentsField,
2020
ZImageConditioningField,
2121
)
22-
from invokeai.app.invocations.z_image_control import ZImageControlField
23-
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
2422
from invokeai.app.invocations.model import TransformerField, VAEField
2523
from invokeai.app.invocations.primitives import LatentsOutput
24+
from invokeai.app.invocations.z_image_control import ZImageControlField
25+
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
2626
from invokeai.app.services.shared.invocation_context import InvocationContext
2727
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
2828
from invokeai.backend.patches.layer_patcher import LayerPatcher

invokeai/backend/model_manager/load/model_loaders/z_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from transformers import AutoTokenizer, Qwen3ForCausalLM
1010

1111
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
12+
from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_ZImage_Config
1213
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
1314
from invokeai.backend.model_manager.configs.main import Main_Checkpoint_ZImage_Config, Main_GGUF_ZImage_Config
1415
from invokeai.backend.model_manager.configs.qwen3_encoder import (

invokeai/backend/z_image/z_image_control_adapter.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,16 @@
1010
control-specific layers (control_layers, control_all_x_embedder, control_noise_refiner).
1111
"""
1212

13-
from typing import Any, Dict, List, Optional
13+
from typing import List, Optional
1414

1515
import torch
1616
import torch.nn as nn
1717
from diffusers.configuration_utils import ConfigMixin, register_to_config
1818
from diffusers.models.modeling_utils import ModelMixin
1919
from diffusers.models.transformers.transformer_z_image import (
20-
ADALN_EMBED_DIM,
2120
SEQ_MULTI_OF,
2221
ZImageTransformerBlock,
2322
)
24-
from diffusers.utils import is_torch_version
2523
from torch.nn.utils.rnn import pad_sequence
2624

2725

@@ -105,7 +103,7 @@ def __init__(
105103

106104
# Control patch embeddings
107105
all_x_embedder = {}
108-
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
106+
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
109107
x_embedder = nn.Linear(
110108
f_patch_size * patch_size * patch_size * control_in_dim,
111109
dim,

invokeai/backend/z_image/z_image_control_transformer.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch.nn as nn
1616
from diffusers.configuration_utils import register_to_config
1717
from diffusers.models.transformers.transformer_z_image import (
18-
ADALN_EMBED_DIM,
1918
SEQ_MULTI_OF,
2019
ZImageTransformer2DModel,
2120
ZImageTransformerBlock,
@@ -151,8 +150,8 @@ def __init__(
151150
cap_feat_dim: int = 2560,
152151
rope_theta: float = 256.0,
153152
t_scale: float = 1000.0,
154-
axes_dims: List[int] = [32, 48, 48],
155-
axes_lens: List[int] = [1024, 512, 512],
153+
axes_dims: tuple[int, ...] = (32, 48, 48),
154+
axes_lens: tuple[int, ...] = (1024, 512, 512),
156155
):
157156
super().__init__(
158157
all_patch_size=all_patch_size,
@@ -174,7 +173,7 @@ def __init__(
174173

175174
# Control layer configuration
176175
self.control_layers_places = (
177-
[i for i in range(0, n_layers, 2)] if control_layers_places is None else control_layers_places
176+
list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
178177
)
179178
self.control_in_dim = in_channels if control_in_dim is None else control_in_dim
180179

@@ -216,7 +215,7 @@ def __init__(
216215

217216
# Control patch embeddings
218217
all_x_embedder = {}
219-
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
218+
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
220219
x_embedder = nn.Linear(
221220
f_patch_size * patch_size * patch_size * self.control_in_dim,
222221
dim,
@@ -585,7 +584,7 @@ def custom_forward(*inputs):
585584
cap_len = cap_item_seqlens[i]
586585
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
587586
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
588-
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
587+
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=True)]
589588
unified_max_item_seqlen = max(unified_item_seqlens)
590589

591590
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
@@ -595,11 +594,11 @@ def custom_forward(*inputs):
595594
unified_attn_mask[i, :seq_len] = 1
596595

597596
# Generate control hints
598-
kwargs = dict(
599-
attn_mask=unified_attn_mask,
600-
freqs_cis=unified_freqs_cis,
601-
adaln_input=adaln_input,
602-
)
597+
kwargs = {
598+
"attn_mask": unified_attn_mask,
599+
"freqs_cis": unified_freqs_cis,
600+
"adaln_input": adaln_input,
601+
}
603602
hints = self.forward_control(
604603
unified,
605604
cap_feats,
@@ -612,13 +611,13 @@ def custom_forward(*inputs):
612611

613612
# Main transformer with control hints
614613
for layer in self.layers:
615-
layer_kwargs = dict(
616-
attn_mask=unified_attn_mask,
617-
freqs_cis=unified_freqs_cis,
618-
adaln_input=adaln_input,
619-
hints=hints,
620-
context_scale=control_context_scale,
621-
)
614+
layer_kwargs = {
615+
"attn_mask": unified_attn_mask,
616+
"freqs_cis": unified_freqs_cis,
617+
"adaln_input": adaln_input,
618+
"hints": hints,
619+
"context_scale": control_context_scale,
620+
}
622621
if torch.is_grad_enabled() and self.gradient_checkpointing:
623622

624623
def create_custom_forward(module, **static_kwargs):

0 commit comments

Comments (0)