
Commit 78d0e1d

WIP not working.
feat: Add Z-Image ControlNet support with spatial conditioning

Add comprehensive ControlNet support for Z-Image models including:

Backend:
- New ControlNet_Checkpoint_ZImage_Config for Z-Image control adapter models
- Z-Image control key detection (_has_z_image_control_keys) to identify control layers
- ZImageControlAdapter loader for standalone control models
- ZImageControlTransformer2DModel combining base transformer with control layers
- Memory-efficient model loading by building combined state dict
1 parent 2802029 commit 78d0e1d
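The backend config classes and detection helpers named in the commit message are not part of the two diffs shown below. As a rough, hypothetical sketch of the key-detection idea only (the real _has_z_image_control_keys lives in the model-manager backend and may differ; the "control_" prefix is inferred from the adapter-weight filter used in z_image_denoise.py further down):

# Hypothetical sketch, not the actual implementation from this commit.
from typing import Any, Mapping


def _has_z_image_control_keys(state_dict: Mapping[str, Any]) -> bool:
    """Guess whether a checkpoint contains Z-Image control-adapter layers."""
    # Assumes control layers are stored under a "control_" key prefix.
    return any(key.startswith("control_") for key in state_dict)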

File tree

18 files changed: 1,673 additions & 68 deletions
invokeai/app/invocations/z_image_control.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Z-Image Control invocation for spatial conditioning."""
+
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    InputField,
+    OutputField,
+)
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
+
+
+class ZImageControlField(BaseModel):
+    """A Z-Image control conditioning field for spatial control (Canny, HED, Depth, Pose, MLSD)."""
+
+    image_name: str = Field(description="The name of the preprocessed control image")
+    control_model: ModelIdentifierField = Field(description="The Z-Image ControlNet adapter model")
+    control_context_scale: float = Field(
+        default=0.75,
+        ge=0.0,
+        le=2.0,
+        description="The strength of the control signal. Recommended range: 0.65-0.80.",
+    )
+    begin_step_percent: float = Field(
+        default=0.0,
+        ge=0.0,
+        le=1.0,
+        description="When the control is first applied (% of total steps)",
+    )
+    end_step_percent: float = Field(
+        default=1.0,
+        ge=0.0,
+        le=1.0,
+        description="When the control is last applied (% of total steps)",
+    )
+
+
+@invocation_output("z_image_control_output")
+class ZImageControlOutput(BaseInvocationOutput):
+    """Z-Image Control output containing control configuration."""
+
+    control: ZImageControlField = OutputField(description="Z-Image control conditioning")
+
+
+@invocation(
+    "z_image_control",
+    title="Z-Image ControlNet",
+    tags=["image", "z-image", "control", "controlnet"],
+    category="control",
+    version="1.1.0",
+    classification=Classification.Prototype,
+)
+class ZImageControlInvocation(BaseInvocation):
+    """Configure Z-Image ControlNet for spatial conditioning.
+
+    Takes a preprocessed control image (e.g., Canny edges, depth map, pose)
+    and a Z-Image ControlNet adapter model to enable spatial control.
+
+    Supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
+    Recommended control_context_scale: 0.65-0.80.
+    """
+
+    image: ImageField = InputField(
+        description="The preprocessed control image (Canny, HED, Depth, Pose, or MLSD)",
+    )
+    control_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.controlnet_model,
+        title="Control Model",
+        ui_model_base=BaseModelType.ZImage,
+        ui_model_type=ModelType.ControlNet,
+    )
+    control_context_scale: float = InputField(
+        default=0.75,
+        ge=0.0,
+        le=2.0,
+        description="Strength of the control signal. Recommended range: 0.65-0.80.",
+        title="Control Scale",
+    )
+    begin_step_percent: float = InputField(
+        default=0.0,
+        ge=0.0,
+        le=1.0,
+        description="When the control is first applied (% of total steps)",
+    )
+    end_step_percent: float = InputField(
+        default=1.0,
+        ge=0.0,
+        le=1.0,
+        description="When the control is last applied (% of total steps)",
+    )
+
+    def invoke(self, context: InvocationContext) -> ZImageControlOutput:
+        return ZImageControlOutput(
+            control=ZImageControlField(
+                image_name=self.image.image_name,
+                control_model=self.control_model,
+                control_context_scale=self.control_context_scale,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
+            )
+        )
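For reference, the invocation above simply bundles its inputs into a ZImageControlField. A minimal sketch of building the same field directly; the image name and model identifier below are illustrative placeholders, not values from this commit:

# Placeholder values for illustration; in a real graph these come from the
# preprocessed control image and the installed Z-Image ControlNet adapter.
control = ZImageControlField(
    image_name="canny_edges.png",     # hypothetical preprocessed control image
    control_model=canny_adapter_id,   # hypothetical ModelIdentifierField
    control_context_scale=0.75,       # recommended range 0.65-0.80
    begin_step_percent=0.0,
    end_step_percent=1.0,
)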

invokeai/app/invocations/z_image_denoise.py

Lines changed: 159 additions & 19 deletions
@@ -1,8 +1,11 @@
 from contextlib import ExitStack
-from typing import Callable, Iterator, Optional, Tuple
+from typing import Callable, Iterator, Optional, Tuple, cast

+import einops
 import torch
 import torchvision.transforms as tv_transforms
+from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm

@@ -18,7 +21,9 @@
     WithMetadata,
     ZImageConditioningField,
 )
-from invokeai.app.invocations.model import LoRAField, TransformerField
+from invokeai.app.invocations.z_image_control import ZImageControlField
+from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
+from invokeai.app.invocations.model import LoRAField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
@@ -29,6 +34,8 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel


 @invocation(
@@ -73,6 +80,18 @@ class ZImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     # Z-Image-Turbo uses 8 steps by default
     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    # Z-Image Control support
+    control: Optional[ZImageControlField] = InputField(
+        default=None,
+        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+        input=Input.Connection,
+    )
+    # VAE for encoding control images (required when using control)
+    vae: Optional[VAEField] = InputField(
+        default=None,
+        description=FieldDescriptions.vae + " Required for control conditioning.",
+        input=Input.Connection,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -283,27 +302,121 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         )

         with ExitStack() as exit_stack:
-            # Load transformer and apply LoRA patches
-            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
-
             # Get transformer config to determine if it's quantized
             transformer_config = context.models.get_config(self.transformer.transformer)

             # Determine if the model is quantized.
             # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
             # slower inference than direct patching, but is agnostic to the quantization format.
-            if transformer_config.format in [ModelFormat.Diffusers]:
+            if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
                 model_is_quantized = False
             elif transformer_config.format in [ModelFormat.GGUFQuantized]:
                 model_is_quantized = True
             else:
                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

-            # Apply LoRA models to the transformer.
+            # Load control adapter and prepare combined transformer if control is provided
+            control_context: list[torch.Tensor] | None = None
+            control_context_scale = 0.75
+            begin_step_percent = 0.0
+            end_step_percent = 1.0
+            cached_weights = None
+
+            if self.control is not None:
+                # Simplified approach: Create model and load weights directly
+                # 1. Get base transformer config (stays on CPU)
+                # 2. Create ZImageControlTransformer2DModel
+                # 3. Load base weights directly (strict=False)
+                # 4. Load adapter control weights on top
+                # 5. Move to GPU
+
+                # Get base transformer config (stays on CPU)
+                base_transformer = cast(ZImageTransformer2DModel, transformer_info.model)
+                base_config = base_transformer.config
+
+                # Load control adapter (stays on CPU)
+                control_model_info = context.models.load(self.control.control_model)
+                control_adapter = control_model_info.model
+                assert isinstance(control_adapter, ZImageControlAdapter)
+
+                # Create ZImageControlTransformer2DModel
+                control_transformer = ZImageControlTransformer2DModel(
+                    all_patch_size=base_config.all_patch_size,
+                    all_f_patch_size=base_config.all_f_patch_size,
+                    in_channels=base_config.in_channels,
+                    dim=base_config.dim,
+                    n_layers=base_config.n_layers,
+                    n_refiner_layers=base_config.n_refiner_layers,
+                    n_heads=base_config.n_heads,
+                    n_kv_heads=base_config.n_kv_heads,
+                    norm_eps=base_config.norm_eps,
+                    qk_norm=base_config.qk_norm,
+                    cap_feat_dim=base_config.cap_feat_dim,
+                    rope_theta=base_config.rope_theta,
+                    t_scale=base_config.t_scale,
+                    axes_dims=base_config.axes_dims,
+                    axes_lens=base_config.axes_lens,
+                )
+
+                # Load base transformer weights directly (strict=False handles missing control keys)
+                control_transformer.load_state_dict(base_transformer.state_dict(), strict=False, assign=True)
+
+                # Load control adapter weights on top (only control-specific keys)
+                # Filter to only control_ prefixed keys to avoid overwriting x_pad_token
+                adapter_control_weights = {
+                    k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")
+                }
+                control_transformer.load_state_dict(adapter_control_weights, strict=False, assign=True)
+
+                # Move to device
+                control_transformer = control_transformer.to(device=device, dtype=inference_dtype)
+                active_transformer = control_transformer
+
+                # Load and prepare control image - must be VAE-encoded!
+                if self.vae is None:
+                    raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                control_image = context.images.get_pil(self.control.image_name)
+
+                # Resize control image to match output dimensions
+                control_image = control_image.convert("RGB")
+                control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                # Convert to tensor format for VAE encoding
+                from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                if control_image_tensor.dim() == 3:
+                    control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                # Encode control image through VAE to get latents
+                vae_info = context.models.load(self.vae.vae)
+                control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                    vae_info=vae_info,
+                    image_tensor=control_image_tensor,
+                )
+
+                # Move to inference device/dtype
+                control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                # Add frame dimension: [B, C, H, W] -> [B, C, 1, H, W]
+                control_latents = control_latents.unsqueeze(2)
+                # Convert to list format expected by transformer
+                control_context = list(control_latents.unbind(dim=0))
+                control_context_scale = self.control.control_context_scale
+                begin_step_percent = self.control.begin_step_percent
+                end_step_percent = self.control.end_step_percent
+            else:
+                # No control - load transformer normally
+                (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+                active_transformer = transformer
+
+            # Apply LoRA models to the active transformer.
             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            # cached_weights is None when using control (since we create a new model), otherwise it's from model_on_device
             exit_stack.enter_context(
                 LayerPatcher.apply_smart_model_patches(
-                    model=transformer,
+                    model=active_transformer,
                     patches=self._lora_iterator(context),
                     prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
                     dtype=inference_dtype,
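The two load_state_dict calls in this hunk implement the "combined state dict" idea from the commit message: base weights are loaded first with strict=False, then only the control_-prefixed adapter weights are overlaid so shared tensors are not clobbered. A standalone sketch of that pattern with plain torch.nn stand-ins (Base and WithControl are illustrative classes, not InvokeAI models):

# Illustration of the strict=False weight-overlay pattern; not InvokeAI code.
import torch
from torch import nn


class Base(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)  # stands in for the base transformer weights


class WithControl(Base):
    def __init__(self) -> None:
        super().__init__()
        self.control_proj = nn.Linear(8, 8)  # extra control-only layer


combined = WithControl()
# 1) Base weights: strict=False tolerates the missing control_* keys.
combined.load_state_dict(Base().state_dict(), strict=False, assign=True)
# 2) Adapter weights: keep only control_* keys so base tensors are left untouched.
adapter_sd = {k: v for k, v in WithControl().state_dict().items() if k.startswith("control_")}
combined.load_state_dict(adapter_sd, strict=False, assign=True)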
@@ -326,28 +439,55 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
                 # Run transformer for positive prediction
                 # Z-Image transformer expects: x as list of [C, 1, H, W] tensors, t, cap_feats as list
                 # Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
-                latent_model_input = latents.to(transformer.dtype)
+                latent_model_input = latents.to(active_transformer.dtype)
                 latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

-                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
-                model_output = transformer(
-                    x=latent_model_input_list,
-                    t=timestep,
-                    cap_feats=[pos_prompt_embeds],
+                # Determine if control should be applied at this step
+                step_percent = step_idx / total_steps
+                apply_control = (
+                    control_context is not None
+                    and step_percent >= begin_step_percent
+                    and step_percent <= end_step_percent
                 )
+
+                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
+                # If control is active, pass control_context to the control transformer
+                if apply_control:
+                    model_output = active_transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                        control_context=control_context,
+                        control_context_scale=control_context_scale,
+                    )
+                else:
+                    model_output = active_transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                    )
                 model_out_list = model_output[0]  # Extract list of tensors from tuple
                 noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                 noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
                 noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation

                 # Apply CFG if enabled
                 if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                    model_output_uncond = transformer(
-                        x=latent_model_input_list,
-                        t=timestep,
-                        cap_feats=[neg_prompt_embeds],
-                    )
+                    if apply_control:
+                        model_output_uncond = active_transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                            control_context=control_context,
+                            control_context_scale=control_context_scale,
+                        )
+                    else:
+                        model_output_uncond = active_transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                        )
                     model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
                     noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                     noise_pred_uncond = noise_pred_uncond.squeeze(2)
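The begin/end gating in this hunk reduces to a simple percentage check per denoising step: control is applied only while step_idx / total_steps falls inside [begin_step_percent, end_step_percent]. A small sketch of just that arithmetic (the helper name is illustrative; the invocation inlines this check as apply_control):

def control_active(step_idx: int, total_steps: int, begin: float, end: float) -> bool:
    # Mirrors the apply_control condition from the diff above.
    step_percent = step_idx / total_steps
    return begin <= step_percent <= end


assert control_active(0, 8, 0.0, 1.0)      # control applied from the first step
assert not control_active(6, 8, 0.0, 0.5)  # 6/8 = 0.75 falls outside [0.0, 0.5]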

0 commit comments
