32 32  from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
33 33  from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
34 34  from invokeai.backend.util.devices import TorchDevice
   35+ from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
   36+ from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
35 37  from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
36 38  from invokeai.backend.z_image.z_image_controlnet_extension import (
37 39      ZImageControlNetExtension,
38 40      z_image_forward_with_control,
39 41  )
   42+ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
4043
4144
42 45  @invocation(
43 46      "z_image_denoise",
44 47      title="Denoise - Z-Image",
45 48      tags=["image", "z-image"],
46 49      category="image",
47   -     version="1.1.0",
   50+     version="1.2.0",
48 51      classification=Classification.Prototype,
49 52  )
50 53  class ZImageDenoiseInvocation(BaseInvocation):
51   -     """Run the denoising process with a Z-Image model."""
   54+     """Run the denoising process with a Z-Image model.
   55+
   56+     Supports regional prompting by connecting multiple conditioning inputs with masks.
   57+     """
5258
5359 # If latents is provided, this means we are doing image-to-image.
5460 latents : Optional [LatentsField ] = InputField (
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
6369 transformer : TransformerField = InputField (
6470 description = FieldDescriptions .z_image_model , input = Input .Connection , title = "Transformer"
6571 )
66   -     positive_conditioning: ZImageConditioningField = InputField(
   72+     positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
67 73          description=FieldDescriptions.positive_cond, input=Input.Connection
68 74      )
69   -     negative_conditioning: Optional[ZImageConditioningField] = InputField(
   75+     negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
70 76          default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
71 77      )
7278 # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
126 132      def _load_text_conditioning(
127 133          self,
128 134          context: InvocationContext,
129    -         conditioning_name: str,
    135+         cond_field: ZImageConditioningField | list[ZImageConditioningField],
    136+         img_height: int,
    137+         img_width: int,
130 138          dtype: torch.dtype,
131 139          device: torch.device,
132- ) -> torch .Tensor :
133- """Load Z-Image text conditioning."""
134- cond_data = context .conditioning .load (conditioning_name )
135- if len (cond_data .conditionings ) != 1 :
136- raise ValueError (
137- f"Expected exactly 1 conditioning entry for Z-Image, got { len (cond_data .conditionings )} . "
138- "Ensure you are using the Z-Image text encoder."
139- )
140- z_image_conditioning = cond_data .conditionings [0 ]
141- if not isinstance (z_image_conditioning , ZImageConditioningInfo ):
142- raise TypeError (
143- f"Expected ZImageConditioningInfo, got { type (z_image_conditioning ).__name__ } . "
144- "Ensure you are using the Z-Image text encoder."
145- )
146- z_image_conditioning = z_image_conditioning .to (dtype = dtype , device = device )
147- return z_image_conditioning .prompt_embeds
140+ ) -> list [ZImageTextConditioning ]:
141+ """Load Z-Image text conditioning with optional regional masks.
142+
143+ Args:
144+ context: The invocation context.
145+ cond_field: Single conditioning field or list of fields.
146+ img_height: Height of the image token grid (H // patch_size).
147+ img_width: Width of the image token grid (W // patch_size).
148+ dtype: Target dtype.
149+ device: Target device.
150+
151+ Returns:
152+ List of ZImageTextConditioning objects with embeddings and masks.
153+ """
154+ # Normalize to a list
155+ cond_list = [cond_field ] if isinstance (cond_field , ZImageConditioningField ) else cond_field
156+
157+ text_conditionings : list [ZImageTextConditioning ] = []
158+ for cond in cond_list :
159+ # Load the text embeddings
160+ cond_data = context .conditioning .load (cond .conditioning_name )
161+ assert len (cond_data .conditionings ) == 1
162+ z_image_conditioning = cond_data .conditionings [0 ]
163+ assert isinstance (z_image_conditioning , ZImageConditioningInfo )
164+ z_image_conditioning = z_image_conditioning .to (dtype = dtype , device = device )
165+ prompt_embeds = z_image_conditioning .prompt_embeds
166+
167+ # Load the mask, if provided
168+ mask : torch .Tensor | None = None
169+ if cond .mask is not None :
170+ mask = context .tensors .load (cond .mask .tensor_name )
171+ mask = mask .to (device = device )
172+ mask = ZImageRegionalPromptingExtension .preprocess_regional_prompt_mask (
173+ mask , img_height , img_width , dtype , device
174+ )
175+
176+ text_conditionings .append (ZImageTextConditioning (prompt_embeds = prompt_embeds , mask = mask ))
177+
178+ return text_conditionings
148179
149180 def _get_noise (
150181 self ,
@@ -221,14 +252,33 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
221252
222253 transformer_info = context .models .load (self .transformer .transformer )
223254
224- # Load positive conditioning
225- pos_prompt_embeds = self ._load_text_conditioning (
255+ # Calculate image token grid dimensions
256+ patch_size = 2 # Z-Image uses patch_size=2
257+ latent_height = self .height // LATENT_SCALE_FACTOR
258+ latent_width = self .width // LATENT_SCALE_FACTOR
259+ img_token_height = latent_height // patch_size
260+ img_token_width = latent_width // patch_size
261+ img_seq_len = img_token_height * img_token_width
262+
263+ # Load positive conditioning with regional masks
264+ pos_text_conditionings = self ._load_text_conditioning (
226265 context = context ,
227- conditioning_name = self .positive_conditioning .conditioning_name ,
266+ cond_field = self .positive_conditioning ,
267+ img_height = img_token_height ,
268+ img_width = img_token_width ,
228269 dtype = inference_dtype ,
229270 device = device ,
230271 )
231272
273+ # Create regional prompting extension
274+ regional_extension = ZImageRegionalPromptingExtension .from_text_conditionings (
275+ text_conditionings = pos_text_conditionings ,
276+ img_seq_len = img_seq_len ,
277+ )
278+
279+ # Get the concatenated prompt embeddings for the transformer
280+ pos_prompt_embeds = regional_extension .regional_text_conditioning .prompt_embeds
281+
232282 # Load negative conditioning if provided and guidance_scale != 1.0
233283 # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
234284 # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
@@ -238,21 +288,22 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
238288 not math .isclose (self .guidance_scale , 1.0 ) and self .negative_conditioning is not None
239289 )
240290 if do_classifier_free_guidance :
241- if self .negative_conditioning is None :
242- raise ValueError ("Negative conditioning is required when guidance_scale != 1.0" )
243- neg_prompt_embeds = self ._load_text_conditioning (
291+ assert self .negative_conditioning is not None
292+ # Load all negative conditionings and concatenate embeddings
293+ # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
294+ neg_text_conditionings = self ._load_text_conditioning (
244295 context = context ,
245- conditioning_name = self .negative_conditioning .conditioning_name ,
296+ cond_field = self .negative_conditioning ,
297+ img_height = img_token_height ,
298+ img_width = img_token_width ,
246299 dtype = inference_dtype ,
247300 device = device ,
248301 )
249-
250- # Calculate image sequence length for timestep shifting
251- patch_size = 2 # Z-Image uses patch_size=2
252- image_seq_len = ((self .height // LATENT_SCALE_FACTOR ) * (self .width // LATENT_SCALE_FACTOR )) // (patch_size ** 2 )
302+ # Concatenate all negative embeddings
303+ neg_prompt_embeds = torch .cat ([tc .prompt_embeds for tc in neg_text_conditionings ], dim = 0 )
253304
254305 # Calculate shift based on image sequence length
255- mu = self ._calculate_shift (image_seq_len )
306+ mu = self ._calculate_shift (img_seq_len )
256307
257308 # Generate sigma schedule with time shift
258309 sigmas = self ._get_sigmas (mu , self .steps )
@@ -443,6 +494,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
443494 )
444495 )
445496
497+ # Apply regional prompting patch if we have regional masks
498+ exit_stack .enter_context (
499+ patch_transformer_for_regional_prompting (
500+ transformer = transformer ,
501+ regional_attn_mask = regional_extension .regional_attn_mask ,
502+ img_seq_len = img_seq_len ,
503+ )
504+ )
505+
446506 # Denoising loop
447507 for step_idx in tqdm (range (total_steps )):
448508 sigma_curr = sigmas [step_idx ]
0 commit comments