Commit 1b5d91d

Merge branch 'main' into feat/z-image-starter-models
2 parents a748519 + 90e3400 commit 1b5d91d

28 files changed (+2789, -81 lines)

invokeai/app/api/routers/workflows.py

Lines changed: 9 additions & 0 deletions
@@ -223,6 +223,15 @@ async def get_workflow_thumbnail(
         raise HTTPException(status_code=404)


+@workflows_router.get("/tags", operation_id="get_all_tags")
+async def get_all_tags(
+    categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
+) -> list[str]:
+    """Gets all unique tags from workflows"""
+
+    return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
+
+
 @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
 async def get_counts_by_tag(
     tags: list[str] = Query(description="The tags to get counts for"),
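
For context, a minimal sketch of how a client might call the new endpoint once the server is running. The base URL, the /api/v1/workflows router prefix, and the category values are assumptions based on InvokeAI's usual API layout, not part of this diff:

import requests

# Assumed local InvokeAI server and the conventional /api/v1/workflows prefix.
BASE_URL = "http://127.0.0.1:9090"

# Fetch every unique workflow tag; repeat the query param to filter by category.
resp = requests.get(
    f"{BASE_URL}/api/v1/workflows/tags",
    params=[("categories", "user"), ("categories", "default")],
)
resp.raise_for_status()
tags: list[str] = resp.json()  # e.g. ["architecture", "portrait", ...]
print(tags)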
invokeai/app/invocations/z_image_control.py (new file)

Lines changed: 112 additions & 0 deletions
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Z-Image Control invocation for spatial conditioning."""

from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    OutputField,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


class ZImageControlField(BaseModel):
    """A Z-Image control conditioning field for spatial control (Canny, HED, Depth, Pose, MLSD)."""

    image_name: str = Field(description="The name of the preprocessed control image")
    control_model: ModelIdentifierField = Field(description="The Z-Image ControlNet adapter model")
    control_context_scale: float = Field(
        default=0.75,
        ge=0.0,
        le=2.0,
        description="The strength of the control signal. Recommended range: 0.65-0.80.",
    )
    begin_step_percent: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="When the control is first applied (% of total steps)",
    )
    end_step_percent: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="When the control is last applied (% of total steps)",
    )


@invocation_output("z_image_control_output")
class ZImageControlOutput(BaseInvocationOutput):
    """Z-Image Control output containing control configuration."""

    control: ZImageControlField = OutputField(description="Z-Image control conditioning")


@invocation(
    "z_image_control",
    title="Z-Image ControlNet",
    tags=["image", "z-image", "control", "controlnet"],
    category="control",
    version="1.1.0",
    classification=Classification.Prototype,
)
class ZImageControlInvocation(BaseInvocation):
    """Configure Z-Image ControlNet for spatial conditioning.

    Takes a preprocessed control image (e.g., Canny edges, depth map, pose)
    and a Z-Image ControlNet adapter model to enable spatial control.

    Supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
    Recommended control_context_scale: 0.65-0.80.
    """

    image: ImageField = InputField(
        description="The preprocessed control image (Canny, HED, Depth, Pose, or MLSD)",
    )
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model,
        title="Control Model",
        ui_model_base=BaseModelType.ZImage,
        ui_model_type=ModelType.ControlNet,
    )
    control_context_scale: float = InputField(
        default=0.75,
        ge=0.0,
        le=2.0,
        description="Strength of the control signal. Recommended range: 0.65-0.80.",
        title="Control Scale",
    )
    begin_step_percent: float = InputField(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="When the control is first applied (% of total steps)",
    )
    end_step_percent: float = InputField(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="When the control is last applied (% of total steps)",
    )

    def invoke(self, context: InvocationContext) -> ZImageControlOutput:
        return ZImageControlOutput(
            control=ZImageControlField(
                image_name=self.image.image_name,
                control_model=self.control_model,
                control_context_scale=self.control_context_scale,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
            )
        )
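
As a quick, standalone illustration of how the pydantic bounds on control_context_scale behave, here is a sketch that mirrors the field constraints above rather than importing the InvokeAI classes (the demo model name is made up):

from pydantic import BaseModel, Field, ValidationError


class ControlScaleDemo(BaseModel):
    # Same bounds as ZImageControlField.control_context_scale above.
    control_context_scale: float = Field(default=0.75, ge=0.0, le=2.0)


print(ControlScaleDemo().control_context_scale)                            # 0.75 (default)
print(ControlScaleDemo(control_context_scale=0.65).control_context_scale)  # low end of the recommended 0.65-0.80 range

try:
    ControlScaleDemo(control_context_scale=2.5)  # violates le=2.0
except ValidationError as err:
    print(err.errors()[0]["type"])  # "less_than_equal" in pydantic v2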

invokeai/app/invocations/z_image_denoise.py

Lines changed: 168 additions & 24 deletions
@@ -1,8 +1,11 @@
+import math
 from contextlib import ExitStack
 from typing import Callable, Iterator, Optional, Tuple

+import einops
 import torch
 import torchvision.transforms as tv_transforms
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm

@@ -16,8 +19,10 @@
     LatentsField,
     ZImageConditioningField,
 )
-from invokeai.app.invocations.model import TransformerField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.invocations.z_image_control import ZImageControlField
+from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
 from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -27,6 +32,11 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_controlnet_extension import (
+    ZImageControlNetExtension,
+    z_image_forward_with_control,
+)


 @invocation(
@@ -59,18 +69,31 @@ class ZImageDenoiseInvocation(BaseInvocation):
     negative_conditioning: Optional[ZImageConditioningField] = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
-    # Z-Image-Turbo uses guidance_scale=0.0 by default (no CFG)
+    # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
     guidance_scale: float = InputField(
-        default=0.0,
-        ge=0.0,
-        description="Guidance scale for classifier-free guidance. Use 0.0 for Z-Image-Turbo.",
+        default=1.0,
+        ge=1.0,
+        description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
+        "Values > 1.0 amplify guidance.",
         title="Guidance Scale",
     )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     # Z-Image-Turbo uses 8 steps by default
     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    # Z-Image Control support
+    control: Optional[ZImageControlField] = InputField(
+        default=None,
+        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+        input=Input.Connection,
+    )
+    # VAE for encoding control images (required when using control)
+    vae: Optional[VAEField] = InputField(
+        default=None,
+        description=FieldDescriptions.vae + " Required for control conditioning.",
+        input=Input.Connection,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -206,12 +229,17 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             device=device,
         )

-        # Load negative conditioning if provided and guidance_scale > 0
+        # Load negative conditioning if provided and guidance_scale != 1.0
+        # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
+        # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
+        # This matches FLUX's convention where 1.0 means "no CFG"
        neg_prompt_embeds: torch.Tensor | None = None
-        do_classifier_free_guidance = self.guidance_scale > 0.0 and self.negative_conditioning is not None
+        do_classifier_free_guidance = (
+            not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
+        )
         if do_classifier_free_guidance:
             if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale > 0")
+                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
             neg_prompt_embeds = self._load_text_conditioning(
                 context=context,
                 conditioning_name=self.negative_conditioning.conditioning_name,
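
To make the comment above concrete, here is a small self-contained sketch of the CFG combine step and why guidance_scale=1.0 is a no-op. The tensor shapes are arbitrary and this is not the invocation's actual code path:

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)

pred_cond = torch.randn(1, 16, 64, 64)
pred_uncond = torch.randn(1, 16, 64, 64)

# At cfg_scale=1.0 the uncond term cancels, so the result equals pred_cond exactly,
# which is why the invocation skips the unconditional forward pass entirely.
assert torch.allclose(cfg_combine(pred_cond, pred_uncond, 1.0), pred_cond)

# At cfg_scale > 1.0 the conditional direction is amplified relative to the uncond prediction.
amplified = cfg_combine(pred_cond, pred_uncond, 3.0)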
@@ -293,9 +321,6 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         )

         with ExitStack() as exit_stack:
-            # Load transformer and apply LoRA patches
-            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
-
             # Get transformer config to determine if it's quantized
             transformer_config = context.models.get_config(self.transformer.transformer)

@@ -309,6 +334,102 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             else:
                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

+            # Load transformer - always use base transformer, control is handled via extension
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+
+            # Prepare control extension if control is provided
+            control_extension: ZImageControlNetExtension | None = None
+
+            if self.control is not None:
+                # Load control adapter using context manager (proper GPU memory management)
+                control_model_info = context.models.load(self.control.control_model)
+                (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
+                assert isinstance(control_adapter, ZImageControlAdapter)
+
+                # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
+                adapter_config = control_adapter.config
+                control_in_dim = adapter_config.get("control_in_dim", 16)
+                num_control_blocks = adapter_config.get("num_control_blocks", 6)
+
+                # Log control configuration for debugging
+                version = "V2.0" if control_in_dim > 16 else "V1"
+                context.util.signal_progress(
+                    f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
+                    f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
+                )
+
+                # Load and prepare control image - must be VAE-encoded!
+                if self.vae is None:
+                    raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                control_image = context.images.get_pil(self.control.image_name)
+
+                # Resize control image to match output dimensions
+                control_image = control_image.convert("RGB")
+                control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                # Convert to tensor format for VAE encoding
+                from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                if control_image_tensor.dim() == 3:
+                    control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                # Encode control image through VAE to get latents
+                vae_info = context.models.load(self.vae.vae)
+                control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                    vae_info=vae_info,
+                    image_tensor=control_image_tensor,
+                )
+
+                # Move to inference device/dtype
+                control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                # Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
+                control_latents = control_latents.squeeze(0).unsqueeze(1)
+
+                # Prepare control_cond based on control_in_dim
+                # V1: 16 channels (just control latents)
+                # V2.0: 33 channels = 16 control + 16 reference + 1 mask
+                # - Channels 0-15: control image latents (from VAE encoding)
+                # - Channels 16-31: reference/inpaint image latents (zeros for pure control)
+                # - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
+                # For pure control (no inpainting), we set mask=1 to tell model "use control, don't inpaint"
+                c, f, h, w = control_latents.shape
+                if c < control_in_dim:
+                    padding_channels = control_in_dim - c
+                    if padding_channels == 17:
+                        # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
+                        ref_padding = torch.zeros(
+                            (16, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        # Mask channel = 1.0 means "don't inpaint this region, use control signal"
+                        mask_channel = torch.ones(
+                            (1, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
+                    else:
+                        # Generic padding with zeros for other cases
+                        zero_padding = torch.zeros(
+                            (padding_channels, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        control_latents = torch.cat([control_latents, zero_padding], dim=0)
+
+                # Create control extension (adapter is already on device from model_on_device)
+                control_extension = ZImageControlNetExtension(
+                    control_adapter=control_adapter,
+                    control_cond=control_latents,
+                    weight=self.control.control_context_scale,
+                    begin_step_percent=self.control.begin_step_percent,
+                    end_step_percent=self.control.end_step_percent,
+                )
+
             # Apply LoRA models to the transformer.
             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
             exit_stack.enter_context(
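
The channel layout described in the comments above can be hard to picture; here is a minimal sketch of how the V2.0 33-channel control_cond tensor is assembled from 16-channel control latents. The shapes and dtype are illustrative only:

import torch

# Assume VAE-encoded control latents: 16 channels, 1 frame, 128x128 latent grid.
control_latents = torch.randn(16, 1, 128, 128)
control_in_dim = 33  # V2.0 adapter; a V1 adapter would expect 16 and need no padding

c, f, h, w = control_latents.shape
if c < control_in_dim and control_in_dim - c == 17:
    ref_padding = torch.zeros(16, f, h, w)   # channels 16-31: reference/inpaint latents (unused here)
    mask_channel = torch.ones(1, f, h, w)    # channel 32: mask=1.0 -> "use control, don't inpaint"
    control_cond = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
else:
    control_cond = control_latents

print(control_cond.shape)  # torch.Size([33, 1, 128, 128])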
@@ -340,25 +461,48 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
                 latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

-                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
-                model_output = transformer(
-                    x=latent_model_input_list,
-                    t=timestep,
-                    cap_feats=[pos_prompt_embeds],
-                )
-                model_out_list = model_output[0]  # Extract list of tensors from tuple
+                # Determine if control should be applied at this step
+                apply_control = control_extension is not None and control_extension.should_apply(step_idx, total_steps)
+
+                # Run forward pass - use custom forward with control if extension is active
+                if apply_control:
+                    model_out_list, _ = z_image_forward_with_control(
+                        transformer=transformer,
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                        control_extension=control_extension,
+                    )
+                else:
+                    model_output = transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                    )
+                    model_out_list = model_output[0]  # Extract list of tensors from tuple
+
                 noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                 noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
                 noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation

                 # Apply CFG if enabled
                 if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                    model_output_uncond = transformer(
-                        x=latent_model_input_list,
-                        t=timestep,
-                        cap_feats=[neg_prompt_embeds],
-                    )
-                    model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+                    if apply_control:
+                        model_out_list_uncond, _ = z_image_forward_with_control(
+                            transformer=transformer,
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                            control_extension=control_extension,
+                        )
+                    else:
+                        model_output_uncond = transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                        )
+                        model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+
                     noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                     noise_pred_uncond = noise_pred_uncond.squeeze(2)
                     noise_pred_uncond = -noise_pred_uncond
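
The per-step gating via should_apply is not shown in this diff. Below is a plausible sketch of the begin/end window check it implies, based only on the begin_step_percent and end_step_percent fields defined above; the actual implementation in ZImageControlNetExtension may differ:

def should_apply_control(step_idx: int, total_steps: int, begin_step_percent: float, end_step_percent: float) -> bool:
    # Map the current step to a 0.0-1.0 progress value and test it against the window.
    if total_steps <= 1:
        return begin_step_percent <= 0.0 <= end_step_percent
    progress = step_idx / (total_steps - 1)
    return begin_step_percent <= progress <= end_step_percent

# With the invocation defaults (begin=0.0, end=1.0, steps=8), control is active on every step:
assert all(should_apply_control(i, 8, 0.0, 1.0) for i in range(8))

# Restricting control to the first half of an 8-step schedule:
print([should_apply_control(i, 8, 0.0, 0.5) for i in range(8)])
# [True, True, True, True, False, False, False, False]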
