+import math
 from contextlib import ExitStack
 from typing import Callable, Iterator, Optional, Tuple

+import einops
 import torch
 import torchvision.transforms as tv_transforms
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm

     LatentsField,
     ZImageConditioningField,
 )
-from invokeai.app.invocations.model import TransformerField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.invocations.z_image_control import ZImageControlField
+from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_controlnet_extension import (
+    ZImageControlNetExtension,
+    z_image_forward_with_control,
+)


 @invocation(
@@ -59,18 +69,31 @@ class ZImageDenoiseInvocation(BaseInvocation):
     negative_conditioning: Optional[ZImageConditioningField] = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
-    # Z-Image-Turbo uses guidance_scale=0.0 by default (no CFG)
+    # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
     guidance_scale: float = InputField(
-        default=0.0,
-        ge=0.0,
-        description="Guidance scale for classifier-free guidance. Use 0.0 for Z-Image-Turbo.",
+        default=1.0,
+        ge=1.0,
+        description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
+        "Values > 1.0 amplify guidance.",
         title="Guidance Scale",
     )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     # Z-Image-Turbo uses 8 steps by default
     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    # Z-Image Control support
+    control: Optional[ZImageControlField] = InputField(
+        default=None,
+        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+        input=Input.Connection,
+    )
+    # VAE for encoding control images (required when using control)
+    vae: Optional[VAEField] = InputField(
+        default=None,
+        description=FieldDescriptions.vae + " Required for control conditioning.",
+        input=Input.Connection,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -206,12 +229,17 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             device=device,
         )

-        # Load negative conditioning if provided and guidance_scale > 0
+        # Load negative conditioning if provided and guidance_scale != 1.0
+        # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
+        # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
+        # This matches FLUX's convention where 1.0 means "no CFG"
         neg_prompt_embeds: torch.Tensor | None = None
-        do_classifier_free_guidance = self.guidance_scale > 0.0 and self.negative_conditioning is not None
+        do_classifier_free_guidance = (
+            not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
+        )
         if do_classifier_free_guidance:
             if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale > 0")
+                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
             neg_prompt_embeds = self._load_text_conditioning(
                 context=context,
                 conditioning_name=self.negative_conditioning.conditioning_name,
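For reference, a minimal sketch of the classifier-free-guidance combine this convention implies (the function name and tensors are illustrative, not taken from this file): at `cfg_scale=1.0` the formula collapses to the conditional prediction, which is why the unconditional pass can be skipped entirely.

```python
# Illustrative only: the "1.0 = no CFG" convention, with assumed names.
import math

import torch


def combine_cfg(pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None, cfg_scale: float) -> torch.Tensor:
    if pred_uncond is None or math.isclose(cfg_scale, 1.0):
        # pred_uncond + 1.0 * (pred_cond - pred_uncond) == pred_cond, so the uncond pass is unnecessary.
        return pred_cond
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)
```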
@@ -293,9 +321,6 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         )

         with ExitStack() as exit_stack:
-            # Load transformer and apply LoRA patches
-            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
-
             # Get transformer config to determine if it's quantized
             transformer_config = context.models.get_config(self.transformer.transformer)

@@ -309,6 +334,102 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             else:
                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

+            # Load transformer - always use base transformer, control is handled via extension
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+
+            # Prepare control extension if control is provided
+            control_extension: ZImageControlNetExtension | None = None
+
+            if self.control is not None:
+                # Load control adapter using context manager (proper GPU memory management)
+                control_model_info = context.models.load(self.control.control_model)
+                (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
+                assert isinstance(control_adapter, ZImageControlAdapter)
+
+                # Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
+                adapter_config = control_adapter.config
+                control_in_dim = adapter_config.get("control_in_dim", 16)
+                num_control_blocks = adapter_config.get("num_control_blocks", 6)
+
+                # Log control configuration for debugging
+                version = "V2.0" if control_in_dim > 16 else "V1"
+                context.util.signal_progress(
+                    f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
+                    f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
+                )
+
+                # Load and prepare control image - must be VAE-encoded!
+                if self.vae is None:
+                    raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                control_image = context.images.get_pil(self.control.image_name)
+
+                # Resize control image to match output dimensions
+                control_image = control_image.convert("RGB")
+                control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                # Convert to tensor format for VAE encoding
+                from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                if control_image_tensor.dim() == 3:
+                    control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                # Encode control image through VAE to get latents
+                vae_info = context.models.load(self.vae.vae)
+                control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                    vae_info=vae_info,
+                    image_tensor=control_image_tensor,
+                )
+
+                # Move to inference device/dtype
+                control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                # Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
+                control_latents = control_latents.squeeze(0).unsqueeze(1)
+
+                # Prepare control_cond based on control_in_dim
+                # V1: 16 channels (just control latents)
+                # V2.0: 33 channels = 16 control + 16 reference + 1 mask
+                # - Channels 0-15: control image latents (from VAE encoding)
+                # - Channels 16-31: reference/inpaint image latents (zeros for pure control)
+                # - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
+                # For pure control (no inpainting), we set mask=1 to tell the model "use control, don't inpaint"
+                c, f, h, w = control_latents.shape
+                if c < control_in_dim:
+                    padding_channels = control_in_dim - c
+                    if padding_channels == 17:
+                        # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
+                        ref_padding = torch.zeros(
+                            (16, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        # Mask channel = 1.0 means "don't inpaint this region, use control signal"
+                        mask_channel = torch.ones(
+                            (1, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
+                    else:
+                        # Generic padding with zeros for other cases
+                        zero_padding = torch.zeros(
+                            (padding_channels, f, h, w),
+                            device=device,
+                            dtype=inference_dtype,
+                        )
+                        control_latents = torch.cat([control_latents, zero_padding], dim=0)
+
+                # Create control extension (adapter is already on device from model_on_device)
+                control_extension = ZImageControlNetExtension(
+                    control_adapter=control_adapter,
+                    control_cond=control_latents,
+                    weight=self.control.control_context_scale,
+                    begin_step_percent=self.control.begin_step_percent,
+                    end_step_percent=self.control.end_step_percent,
+                )
+
             # Apply LoRA models to the transformer.
             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
             exit_stack.enter_context(
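As a standalone illustration of the 33-channel V2.0 `control_cond` layout assembled in the block above (the shapes are made-up example values; only the channel ordering comes from the comments in this diff):

```python
# Illustrative shapes; the channel ordering mirrors the V2.0 layout described above.
import torch

f, h, w = 1, 128, 128                   # frames, latent height, latent width (example values)
control = torch.randn(16, f, h, w)      # channels 0-15: VAE-encoded control image latents
reference = torch.zeros(16, f, h, w)    # channels 16-31: reference/inpaint latents (zeros for pure control)
mask = torch.ones(1, f, h, w)           # channel 32: 1.0 = "don't inpaint, follow the control signal"
control_cond = torch.cat([control, reference, mask], dim=0)
assert control_cond.shape == (33, f, h, w)
```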
@@ -340,25 +461,48 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
                 latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

-                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
-                model_output = transformer(
-                    x=latent_model_input_list,
-                    t=timestep,
-                    cap_feats=[pos_prompt_embeds],
-                )
-                model_out_list = model_output[0]  # Extract list of tensors from tuple
+                # Determine if control should be applied at this step
+                apply_control = control_extension is not None and control_extension.should_apply(step_idx, total_steps)
+
+                # Run forward pass - use custom forward with control if extension is active
+                if apply_control:
+                    model_out_list, _ = z_image_forward_with_control(
+                        transformer=transformer,
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                        control_extension=control_extension,
+                    )
+                else:
+                    model_output = transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                    )
+                    model_out_list = model_output[0]  # Extract list of tensors from tuple
+
                 noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                 noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
                 noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation

                 # Apply CFG if enabled
                 if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                    model_output_uncond = transformer(
-                        x=latent_model_input_list,
-                        t=timestep,
-                        cap_feats=[neg_prompt_embeds],
-                    )
-                    model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+                    if apply_control:
+                        model_out_list_uncond, _ = z_image_forward_with_control(
+                            transformer=transformer,
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                            control_extension=control_extension,
+                        )
+                    else:
+                        model_output_uncond = transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                        )
+                        model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
+
                     noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                     noise_pred_uncond = noise_pred_uncond.squeeze(2)
                     noise_pred_uncond = -noise_pred_uncond
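The loop gates the extension per step through `control_extension.should_apply(step_idx, total_steps)`. A plausible sketch of that check, assuming it maps `begin_step_percent`/`end_step_percent` onto normalized denoising progress (the real implementation lives in `ZImageControlNetExtension` and is not shown in this diff):

```python
# Assumed behavior, for illustration only: apply control only inside the configured progress window.
def should_apply(step_idx: int, total_steps: int, begin_step_percent: float, end_step_percent: float) -> bool:
    progress = step_idx / max(total_steps - 1, 1)  # 0.0 at the first step, 1.0 at the last
    return begin_step_percent <= progress <= end_step_percent
```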