
Commit 65efc3d

Feature: Add Z-Image-Turbo regional guidance (#8672)
* feat: Add Regional Guidance support for Z-Image model

  Implements regional prompting for Z-Image (S3-DiT Transformer), allowing different prompts to affect different image regions using attention masks.

  Backend changes:
  - Add ZImageRegionalPromptingExtension for mask preparation
  - Add ZImageTextConditioning and ZImageRegionalTextConditioning data classes
  - Patch the transformer forward pass to inject 4D regional attention masks
  - Use an additive float mask (0.0 = attend, -inf = block) in bfloat16 for compatibility
  - Alternate regional and full attention layers for global coherence

  Frontend changes:
  - Update buildZImageGraph to support regional conditioning collectors
  - Update addRegions to create z_image_text_encoder nodes for regions
  - Update addZImageLoRAs to handle optional negCond when guidance_scale=0
  - Add Z-Image validation (no IP adapters, no autoNegative)

* @Pfannkuchensack Fix Windows path again

* ruff check fix

* ruff formatting

* fix(ui): Z-Image CFG guidance_scale check uses > 1 instead of > 0

  Changed the guidance_scale check from > 0 to > 1 for Z-Image models. Since Z-Image uses guidance_scale=1.0 as "no CFG" (matching FLUX convention), negative conditioning should only be created when guidance_scale > 1.

---------

Co-authored-by: Lincoln Stein <[email protected]>
1 parent de1aa55 commit 65efc3d
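The additive-mask detail in the commit message is worth unpacking: instead of a boolean attention mask, the patch uses a float mask where 0.0 means "attend" and a large negative value means "block", added to the attention logits before softmax. A minimal sketch of that conversion (hypothetical helper, not code from this commit):

import torch

def to_additive_mask(allow: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Convert a boolean attention mask (True = attend) to an additive float mask.

    Allowed positions become 0.0; blocked positions get the most negative
    representable value, driving their post-softmax weight to ~0. Using
    torch.finfo(dtype).min rather than float("-inf") is a common bfloat16-safe
    choice; the commit message says "-inf", so this detail is an assumption.
    """
    additive = torch.zeros(allow.shape, dtype=dtype)
    additive.masked_fill_(~allow, torch.finfo(dtype).min)
    return additive

# usage: attn_weights = torch.softmax(scores + to_additive_mask(allow_mask), dim=-1)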

File tree

14 files changed, +24511 -5347 lines changed

invokeai/app/invocations/fields.py

Lines changed: 5 additions & 0 deletions
@@ -333,6 +333,11 @@ class ZImageConditioningField(BaseModel):
     """A Z-Image conditioning tensor primitive value"""
 
     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )
 
 
 class ConditioningField(BaseModel):
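To make the True/False convention concrete, here is a hypothetical way a caller might build and attach such a mask (names and values invented for illustration):

import torch

# Boolean region mask at image resolution: True = the prompt applies here.
height, width = 1024, 1024
region_mask = torch.zeros((height, width), dtype=torch.bool)
region_mask[:, : width // 2] = True  # this prompt affects only the left half

# The mask is saved as a tensor and referenced by name from the field, e.g.:
# ZImageConditioningField(conditioning_name="...", mask=TensorField(tensor_name="..."))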

invokeai/app/invocations/z_image_denoise.py

Lines changed: 93 additions & 33 deletions
@@ -32,23 +32,29 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_controlnet_extension import (
     ZImageControlNetExtension,
     z_image_forward_with_control,
 )
+from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
 
 
 @invocation(
     "z_image_denoise",
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class ZImageDenoiseInvocation(BaseInvocation):
-    """Run the denoising process with a Z-Image model."""
+    """Run the denoising process with a Z-Image model.
+
+    Supports regional prompting by connecting multiple conditioning inputs with masks.
+    """
 
     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
     transformer: TransformerField = InputField(
         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
     )
-    positive_conditioning: ZImageConditioningField = InputField(
+    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_conditioning: Optional[ZImageConditioningField] = InputField(
+    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
     def _load_text_conditioning(
         self,
         context: InvocationContext,
-        conditioning_name: str,
+        cond_field: ZImageConditioningField | list[ZImageConditioningField],
+        img_height: int,
+        img_width: int,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> torch.Tensor:
-        """Load Z-Image text conditioning."""
-        cond_data = context.conditioning.load(conditioning_name)
-        if len(cond_data.conditionings) != 1:
-            raise ValueError(
-                f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = cond_data.conditionings[0]
-        if not isinstance(z_image_conditioning, ZImageConditioningInfo):
-            raise TypeError(
-                f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
-        return z_image_conditioning.prompt_embeds
+    ) -> list[ZImageTextConditioning]:
+        """Load Z-Image text conditioning with optional regional masks.
+
+        Args:
+            context: The invocation context.
+            cond_field: Single conditioning field or list of fields.
+            img_height: Height of the image token grid (H // patch_size).
+            img_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype.
+            device: Target device.
+
+        Returns:
+            List of ZImageTextConditioning objects with embeddings and masks.
+        """
+        # Normalize to a list
+        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+        text_conditionings: list[ZImageTextConditioning] = []
+        for cond in cond_list:
+            # Load the text embeddings
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            z_image_conditioning = cond_data.conditionings[0]
+            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+            prompt_embeds = z_image_conditioning.prompt_embeds
+
+            # Load the mask, if provided
+            mask: torch.Tensor | None = None
+            if cond.mask is not None:
+                mask = context.tensors.load(cond.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, img_height, img_width, dtype, device
+                )
+
+            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+        return text_conditionings
 
     def _get_noise(
         self,
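preprocess_regional_prompt_mask itself lives in the new extension module, which is not shown in this commit's visible diffs. A plausible sketch of what such a helper must do, downsampling the image-space mask to the token grid so there is one value per image token (an assumption, not the PR's implementation; the real helper also takes a dtype, presumably for the later additive-mask conversion):

import torch
import torch.nn.functional as F

def preprocess_regional_prompt_mask_sketch(
    mask: torch.Tensor, img_height: int, img_width: int, device: torch.device
) -> torch.Tensor:
    """Downsample an image-space mask to the transformer's (img_height, img_width) token grid.

    Accepts (H, W) or (1, H, W) masks and returns a (1, img_height * img_width)
    boolean mask with one entry per image token.
    """
    mask = mask.to(device=device, dtype=torch.float32)
    while mask.dim() < 4:  # F.interpolate expects (N, C, H, W)
        mask = mask.unsqueeze(0)
    mask = F.interpolate(mask, size=(img_height, img_width), mode="nearest")
    return mask.flatten(start_dim=1).bool()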
@@ -221,14 +252,33 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         transformer_info = context.models.load(self.transformer.transformer)
 
-        # Load positive conditioning
-        pos_prompt_embeds = self._load_text_conditioning(
+        # Calculate image token grid dimensions
+        patch_size = 2  # Z-Image uses patch_size=2
+        latent_height = self.height // LATENT_SCALE_FACTOR
+        latent_width = self.width // LATENT_SCALE_FACTOR
+        img_token_height = latent_height // patch_size
+        img_token_width = latent_width // patch_size
+        img_seq_len = img_token_height * img_token_width
+
+        # Load positive conditioning with regional masks
+        pos_text_conditionings = self._load_text_conditioning(
             context=context,
-            conditioning_name=self.positive_conditioning.conditioning_name,
+            cond_field=self.positive_conditioning,
+            img_height=img_token_height,
+            img_width=img_token_width,
             dtype=inference_dtype,
             device=device,
         )
 
+        # Create regional prompting extension
+        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+            text_conditionings=pos_text_conditionings,
+            img_seq_len=img_seq_len,
+        )
+
+        # Get the concatenated prompt embeddings for the transformer
+        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
+
         # Load negative conditioning if provided and guidance_scale != 1.0
         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
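from_text_conditionings concatenates the per-region prompt embeddings and derives a single attention mask over the joint text+image token sequence. Its exact construction is in the extension module, not this diff; a simplified sketch of the core idea, pairing each prompt's text tokens with its image region (my assumption about the construction, and it omits the alternating regional/full layers noted in the commit message):

import torch

def build_regional_attn_mask_sketch(
    text_lens: list[int],              # tokens per regional prompt, in concat order
    region_masks: list[torch.Tensor],  # each (1, img_seq_len) bool; True = token in region
    img_seq_len: int,
) -> torch.Tensor:
    """Build a boolean allow-mask over the [text | image] joint sequence."""
    txt_seq_len = sum(text_lens)
    total = txt_seq_len + img_seq_len
    allow = torch.zeros((total, total), dtype=torch.bool)

    t0 = 0
    for t_len, region in zip(text_lens, region_masks):
        t1 = t0 + t_len
        r = region.squeeze(0)  # (img_seq_len,)
        allow[t0:t1, t0:t1] = True                   # text tokens attend their own prompt
        allow[t0:t1, txt_seq_len:] = r               # text tokens -> their image region
        allow[txt_seq_len:, t0:t1] = r.unsqueeze(1)  # image region -> its text tokens
        t0 = t1

    # Image tokens attend other image tokens that share at least one region.
    img_img = torch.zeros((img_seq_len, img_seq_len), dtype=torch.bool)
    for region in region_masks:
        r = region.squeeze(0)
        img_img |= r.unsqueeze(0) & r.unsqueeze(1)
    allow[txt_seq_len:, txt_seq_len:] = img_img

    # Later converted to an additive float mask (0.0 / -inf) and broadcast to the
    # 4D (batch, heads, seq, seq) shape mentioned in the commit message.
    return allow

Under this sketch, pixels covered by no mask would attend nothing; the real extension presumably handles unmasked (global) prompts and uncovered regions more carefully.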
@@ -238,21 +288,22 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
         )
         if do_classifier_free_guidance:
-            if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
-            neg_prompt_embeds = self._load_text_conditioning(
+            assert self.negative_conditioning is not None
+            # Load all negative conditionings and concatenate embeddings
+            # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+            neg_text_conditionings = self._load_text_conditioning(
                 context=context,
-                conditioning_name=self.negative_conditioning.conditioning_name,
+                cond_field=self.negative_conditioning,
+                img_height=img_token_height,
+                img_width=img_token_width,
                 dtype=inference_dtype,
                 device=device,
             )
-
-        # Calculate image sequence length for timestep shifting
-        patch_size = 2  # Z-Image uses patch_size=2
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
+            # Concatenate all negative embeddings
+            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
 
         # Calculate shift based on image sequence length
-        mu = self._calculate_shift(image_seq_len)
+        mu = self._calculate_shift(img_seq_len)
 
         # Generate sigma schedule with time shift
         sigmas = self._get_sigmas(mu, self.steps)
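The CFG arithmetic referenced by the comments in the previous hunk, written out as a self-contained helper (the function name is hypothetical; the formula comes from the diff's own comments):

import torch

def apply_cfg(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Classifier-free guidance combine: pred = pred_uncond + scale * (pred_cond - pred_uncond)."""
    # At guidance_scale == 1.0 this reduces to pred_cond exactly, which is why the
    # node skips the uncond pass (and negative conditioning) in that case.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)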
@@ -443,6 +494,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
                 )
             )
 
+            # Apply regional prompting patch if we have regional masks
+            exit_stack.enter_context(
+                patch_transformer_for_regional_prompting(
+                    transformer=transformer,
+                    regional_attn_mask=regional_extension.regional_attn_mask,
+                    img_seq_len=img_seq_len,
+                )
+            )
+
             # Denoising loop
             for step_idx in tqdm(range(total_steps)):
                 sigma_curr = sigmas[step_idx]
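patch_transformer_for_regional_prompting is a context manager, so the ExitStack restores the original forward methods after denoising. Its internals are not part of this diff; a minimal sketch of the shape, following the commit message's "alternate regional/full attention layers" idea (the block attribute, kwarg name, and even/odd split are all assumptions):

from contextlib import contextmanager
from typing import Callable, Iterator, Optional

import torch

def _inject_mask(forward_fn: Callable, attn_mask: torch.Tensor) -> Callable:
    def wrapped(*args, **kwargs):
        kwargs["attn_mask"] = attn_mask  # kwarg name is an assumption
        return forward_fn(*args, **kwargs)
    return wrapped

@contextmanager
def patch_for_regional_prompting_sketch(
    transformer: torch.nn.Module,
    regional_attn_mask: Optional[torch.Tensor],
) -> Iterator[None]:
    """Temporarily apply the regional mask on alternating transformer blocks."""
    if regional_attn_mask is None:
        yield  # no regional masks: leave the transformer untouched
        return
    originals: dict[int, Callable] = {}
    try:
        for i, block in enumerate(transformer.blocks):  # attribute name is an assumption
            if i % 2 == 0:  # even layers regional, odd layers full attention
                originals[i] = block.forward
                block.forward = _inject_mask(block.forward, regional_attn_mask)
        yield
    finally:
        # Always restore the original forwards, even if denoising raises.
        for i, fwd in originals.items():
            transformer.blocks[i].forward = fwd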

invokeai/app/invocations/z_image_text_encoder.py

Lines changed: 21 additions & 5 deletions
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 
 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,25 +34,34 @@
     title="Prompt - Z-Image",
     tags=["prompt", "conditioning", "z-image"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class ZImageTextEncoderInvocation(BaseInvocation):
-    """Encodes and preps a prompt for a Z-Image image."""
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """
 
     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     qwen3_encoder: Qwen3EncoderField = InputField(
         title="Qwen3 Encoder",
         description=FieldDescriptions.qwen3_encoder,
         input=Input.Connection,
     )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
-        return ZImageConditioningOutput.build(conditioning_name)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
 
     def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         """Encode prompt using Qwen3 text encoder.

invokeai/backend/z_image/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Z-Image Control Transformer support for InvokeAI
+# Z-Image backend utilities
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
 from invokeai.backend.z_image.z_image_controlnet_extension import (
invokeai/backend/z_image/extensions/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Z-Image extensions
