Commit bde3acc

Add Z-Image ControlNet V2.0 support
VRAM usage is high.

- Auto-detect control_in_dim from adapter weights (16 for V1, 33 for V2.0)
- Auto-detect n_refiner_layers from the state dict
- Add zero-padding for V2.0's additional channels
- Use accelerate.init_empty_weights() for efficient model creation
- Add ControlNet_Checkpoint_ZImage_Config to the frontend schema
1 parent e211ac9 commit bde3acc

3 files changed: +327, -59 lines
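
As context for the diffs below: with Z-Image's all_patch_size=(2,) and all_f_patch_size=(1,), each latent channel contributes 4 values to the control embedder's input row, which is what makes control_in_dim recoverable from the embedder weight shape. A quick sketch of that arithmetic (the 16/33 figures and the divide-by-4 come from this commit; the loop itself is illustrative only):

```python
# Sketch of the channel arithmetic behind the V1/V2.0 auto-detection.
# patch_size=2, f_patch_size=1 are the Z-Image values used in this commit.
patch_size, f_patch_size = 2, 1
values_per_channel = f_patch_size * patch_size * patch_size  # = 4

# V1 embedder:   in_features = 16 * 4 = 64  -> control_in_dim = 16
# V2.0 embedder: in_features = 33 * 4 = 132 -> control_in_dim = 33
for in_features in (64, 132):
    control_in_dim = in_features // values_per_channel
    version = "V2.0" if control_in_dim > 16 else "V1"
    print(f"in_features={in_features} -> control_in_dim={control_in_dim} ({version})")
```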

invokeai/app/invocations/z_image_denoise.py

Lines changed: 73 additions & 39 deletions
```diff
@@ -333,57 +333,74 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         begin_step_percent = 0.0
         end_step_percent = 1.0
         cached_weights = None
+        control_in_dim = 16  # Default, will be set from adapter config if control is used
 
         if self.control is not None:
-            # Simplified approach: Create model and load weights directly
-            # 1. Get base transformer config (stays on CPU)
-            # 2. Create ZImageControlTransformer2DModel
-            # 3. Load base weights directly (strict=False)
-            # 4. Load adapter control weights on top
-            # 5. Move to GPU
-
-            # Get base transformer config (stays on CPU)
+            # Load base transformer config (NOT to GPU yet - just get the model reference)
             base_transformer = cast(ZImageTransformer2DModel, transformer_info.model)
             base_config = base_transformer.config
 
-            # Load control adapter (stays on CPU)
+            # Load control adapter
             control_model_info = context.models.load(self.control.control_model)
             control_adapter = control_model_info.model
             assert isinstance(control_adapter, ZImageControlAdapter)
 
-            # Create ZImageControlTransformer2DModel
-            control_transformer = ZImageControlTransformer2DModel(
-                all_patch_size=base_config.all_patch_size,
-                all_f_patch_size=base_config.all_f_patch_size,
-                in_channels=base_config.in_channels,
-                dim=base_config.dim,
-                n_layers=base_config.n_layers,
-                n_refiner_layers=base_config.n_refiner_layers,
-                n_heads=base_config.n_heads,
-                n_kv_heads=base_config.n_kv_heads,
-                norm_eps=base_config.norm_eps,
-                qk_norm=base_config.qk_norm,
-                cap_feat_dim=base_config.cap_feat_dim,
-                rope_theta=base_config.rope_theta,
-                t_scale=base_config.t_scale,
-                axes_dims=base_config.axes_dims,
-                axes_lens=base_config.axes_lens,
+            # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
+            adapter_config = control_adapter.config
+            control_in_dim = adapter_config.get("control_in_dim", 16)
+            num_control_blocks = adapter_config.get("num_control_blocks", 6)
+            n_refiner_layers = adapter_config.get("n_refiner_layers", 2)
+
+            # Calculate control_layers_places based on num_control_blocks
+            control_layers_places = [i * 2 for i in range(num_control_blocks)]
+
+            # Log control configuration for debugging
+            version = "V2.0" if control_in_dim > 16 else "V1"
+            context.util.signal_progress(
+                f"Using Z-Image ControlNet {version}: control_in_dim={control_in_dim}, "
+                f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
             )
 
-            # Load base transformer weights directly (strict=False handles missing control keys)
+            # Create control transformer structure with empty weights
+            import accelerate
+
+            with accelerate.init_empty_weights():
+                control_transformer = ZImageControlTransformer2DModel(
+                    control_layers_places=control_layers_places,
+                    control_in_dim=control_in_dim,
+                    all_patch_size=base_config.all_patch_size,
+                    all_f_patch_size=base_config.all_f_patch_size,
+                    in_channels=base_config.in_channels,
+                    dim=base_config.dim,
+                    n_layers=base_config.n_layers,
+                    n_refiner_layers=n_refiner_layers,
+                    n_heads=base_config.n_heads,
+                    n_kv_heads=base_config.n_kv_heads,
+                    norm_eps=base_config.norm_eps,
+                    qk_norm=base_config.qk_norm,
+                    cap_feat_dim=base_config.cap_feat_dim,
+                    rope_theta=base_config.rope_theta,
+                    t_scale=base_config.t_scale,
+                    axes_dims=base_config.axes_dims,
+                    axes_lens=base_config.axes_lens,
+                )
+
+            # Load base weights with assign=True (assigns tensors directly, no copy of data)
             control_transformer.load_state_dict(base_transformer.state_dict(), strict=False, assign=True)
 
-            # Load control adapter weights on top (only control-specific keys)
-            # Filter to only control_ prefixed keys to avoid overwriting x_pad_token
-            adapter_control_weights = {
-                k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")
-            }
-            control_transformer.load_state_dict(adapter_control_weights, strict=False, assign=True)
+            # Load control adapter weights on top
+            adapter_weights = {k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")}
+            control_transformer.load_state_dict(adapter_weights, strict=False, assign=True)
 
-            # Move to device
+            # Move combined model to device
             control_transformer = control_transformer.to(device=device, dtype=inference_dtype)
             active_transformer = control_transformer
 
+            # Clean up to save memory (TODO: verify this is safe to free here)
+            # del control_adapter
+            # if torch.cuda.is_available():
+            #     torch.cuda.empty_cache()
+
             # Load and prepare control image - must be VAE-encoded!
             if self.vae is None:
                 raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
```
```diff
@@ -413,8 +430,23 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Add frame dimension: [B, C, H, W] -> [B, C, 1, H, W]
             control_latents = control_latents.unsqueeze(2)
-            # Convert to list format expected by transformer
+
+            # Prepare control_context based on control_in_dim
+            # V1: 16 channels (just control latents)
+            # V2.0: 33 channels (control latents + zero padding)
+            # Following diffusers approach: simple zero-padding to match control_in_dim
+            b, c, f, h, w = control_latents.shape
+            if c < control_in_dim:
+                # Pad with zeros to match control_in_dim (diffusers approach)
+                padding_channels = control_in_dim - c
+                zero_padding = torch.zeros(
+                    (b, padding_channels, f, h, w),
+                    device=device,
+                    dtype=inference_dtype,
+                )
+                control_latents = torch.cat([control_latents, zero_padding], dim=1)
             control_context = list(control_latents.unbind(dim=0))
+
             control_context_scale = self.control.control_context_scale
             begin_step_percent = self.control.begin_step_percent
             end_step_percent = self.control.end_step_percent
```
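
To make the padding step concrete, here is the shape flow in isolation, assuming the 16-channel Z-Image VAE latents, a V2.0 adapter (control_in_dim=33), and a hypothetical 64x64 latent grid:

```python
import torch

control_in_dim = 33  # V2.0 adapter
control_latents = torch.randn(1, 16, 1, 64, 64)  # [B, C, F, H, W] after unsqueeze(2)

b, c, f, h, w = control_latents.shape
if c < control_in_dim:
    # V2.0: append 33 - 16 = 17 zero channels, matching the diffusers approach.
    zero_padding = torch.zeros((b, control_in_dim - c, f, h, w), dtype=control_latents.dtype)
    control_latents = torch.cat([control_latents, zero_padding], dim=1)

print(control_latents.shape)  # torch.Size([1, 33, 1, 64, 64])
# A V1 adapter (control_in_dim=16) passes the 16-channel latents through unchanged.
```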
```diff
@@ -425,7 +457,8 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         # Apply LoRA models to the active transformer.
         # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
-        # cached_weights is None when using control (since we create a new model), otherwise it's from model_on_device
+        # cached_weights is None when using control (since we create a new combined model),
+        # otherwise it comes from model_on_device() context.
         exit_stack.enter_context(
             LayerPatcher.apply_smart_model_patches(
                 model=active_transformer,
```
```diff
@@ -457,15 +490,16 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Determine if control should be applied at this step
             step_percent = step_idx / total_steps
+            use_control = self.control is not None
             apply_control = (
-                control_context is not None
+                use_control
                 and step_percent >= begin_step_percent
                 and step_percent <= end_step_percent
             )
 
             # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
             # If control is active, pass control_context to the control transformer
-            if apply_control:
+            if apply_control and control_context is not None:
                 model_output = active_transformer(
                     x=latent_model_input_list,
                     t=timestep,
```
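
The gating here is plain percentage bracketing over the step index. A quick illustration with hypothetical values (begin/end of 0.0/0.5 over 20 steps applies control for steps 0 through 10):

```python
total_steps = 20
begin_step_percent, end_step_percent = 0.0, 0.5
use_control = True  # i.e. self.control is not None

active_steps = [
    step_idx
    for step_idx in range(total_steps)
    if use_control and begin_step_percent <= step_idx / total_steps <= end_step_percent
]
print(active_steps)  # [0, 1, ..., 10]
```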
```diff
@@ -486,7 +520,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Apply CFG if enabled
             if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                if apply_control:
+                if apply_control and control_context is not None:
                     model_output_uncond = active_transformer(
                         x=latent_model_input_list,
                         t=timestep,
```
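
The combine step is not shown in this hunk, but the second (unconditioned) forward pass exists to feed the standard classifier-free guidance blend. A sketch of that formula with fabricated shapes and a hypothetical guidance scale:

```python
import torch

guidance_scale = 4.0  # hypothetical value; not shown in this diff
cond = torch.randn(1, 16, 64, 64)    # model_output (prompt-conditioned)
uncond = torch.randn(1, 16, 64, 64)  # model_output_uncond

# Standard CFG: push the prediction away from the unconditioned output.
guided = uncond + guidance_scale * (cond - uncond)
print(guided.shape)  # torch.Size([1, 16, 64, 64])
```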

invokeai/backend/model_manager/load/model_loaders/z_image.py

Lines changed: 32 additions & 2 deletions
```diff
@@ -420,16 +420,46 @@ def _load_control_adapter(
                     control_block_indices.add(int(parts[1]))
         num_control_blocks = len(control_block_indices) if control_block_indices else 6
 
+        # Determine number of refiner layers from state dict
+        refiner_indices: set[int] = set()
+        for key in sd.keys():
+            if key.startswith("control_noise_refiner."):
+                parts = key.split(".")
+                if len(parts) > 1 and parts[1].isdigit():
+                    refiner_indices.add(int(parts[1]))
+        n_refiner_layers = len(refiner_indices) if refiner_indices else 2
 
+        # Determine control_in_dim from embedder weight shape
+        # control_in_dim = weight.shape[1] / (f_patch_size * patch_size * patch_size)
+        # For patch_size=2, f_patch_size=1: control_in_dim = weight.shape[1] / 4
+        control_in_dim = 16  # Default for V1
+        embedder_key = "control_all_x_embedder.2-1.weight"
+        if embedder_key in sd:
+            weight_shape = sd[embedder_key].shape
+            # weight_shape[1] = f_patch_size * patch_size * patch_size * control_in_dim
+            control_in_dim = weight_shape[1] // 4  # 4 = 1 * 2 * 2
 
+        # Log detected configuration for debugging
+        from invokeai.backend.util.logging import InvokeAILogger
 
+        logger = InvokeAILogger.get_logger(self.__class__.__name__)
+        version = "V2.0" if control_in_dim > 16 else "V1"
+        logger.info(
+            f"Z-Image ControlNet detected: {version} "
+            f"(control_in_dim={control_in_dim}, num_control_blocks={num_control_blocks}, "
+            f"n_refiner_layers={n_refiner_layers})"
+        )
+
         # Create an empty control adapter
         dim = 3840
         with accelerate.init_empty_weights():
             model = ZImageControlAdapter(
                 num_control_blocks=num_control_blocks,
-                control_in_dim=16,
+                control_in_dim=control_in_dim,
                 all_patch_size=(2,),
                 all_f_patch_size=(1,),
                 dim=dim,
-                n_refiner_layers=2,
+                n_refiner_layers=n_refiner_layers,
                 n_heads=30,
                 n_kv_heads=30,
                 norm_eps=1e-05,
```
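
Both detection passes can be exercised against a dummy state dict. This sketch mirrors the loader logic above with fabricated V2.0-shaped tensors (the key names come from the diff; the tensor shapes are stand-ins):

```python
import torch

# Fabricated V2.0-style checkpoint: 2 refiner layers, 33 * 4 = 132 input features.
sd = {
    "control_noise_refiner.0.attn.qkv.weight": torch.zeros(1, 1),
    "control_noise_refiner.1.attn.qkv.weight": torch.zeros(1, 1),
    "control_all_x_embedder.2-1.weight": torch.zeros(3840, 132),
}

# Count distinct refiner layer indices from the key prefixes.
refiner_indices = {
    int(k.split(".")[1])
    for k in sd
    if k.startswith("control_noise_refiner.") and k.split(".")[1].isdigit()
}
n_refiner_layers = len(refiner_indices) if refiner_indices else 2

# Recover control_in_dim from the embedder's input width (divide by 1 * 2 * 2).
control_in_dim = 16  # V1 default
if "control_all_x_embedder.2-1.weight" in sd:
    control_in_dim = sd["control_all_x_embedder.2-1.weight"].shape[1] // 4

print(n_refiner_layers, control_in_dim)  # 2 33
```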
