@@ -1,8 +1,11 @@
 from contextlib import ExitStack
-from typing import Callable, Iterator, Optional, Tuple
+from typing import Callable, Iterator, Optional, Tuple, cast
 
+import einops
 import torch
 import torchvision.transforms as tv_transforms
+from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm
 
@@ -18,7 +21,9 @@
     WithMetadata,
     ZImageConditioningField,
 )
-from invokeai.app.invocations.model import LoRAField, TransformerField
+from invokeai.app.invocations.z_image_control import ZImageControlField
+from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
+from invokeai.app.invocations.model import LoRAField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
@@ -29,6 +34,8 @@
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
 
 
 @invocation(
@@ -73,6 +80,18 @@ class ZImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     # Z-Image-Turbo uses 8 steps by default
     steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    # Z-Image Control support
+    control: Optional[ZImageControlField] = InputField(
+        default=None,
+        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
+        input=Input.Connection,
+    )
+    # VAE for encoding control images (required when using control)
+    vae: Optional[VAEField] = InputField(
+        default=None,
+        description=FieldDescriptions.vae + " Required for control conditioning.",
+        input=Input.Connection,
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -283,27 +302,121 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         )
 
         with ExitStack() as exit_stack:
-            # Load transformer and apply LoRA patches
-            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
-
             # Get transformer config to determine if it's quantized
             transformer_config = context.models.get_config(self.transformer.transformer)
 
             # Determine if the model is quantized.
             # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
             # slower inference than direct patching, but is agnostic to the quantization format.
-            if transformer_config.format in [ModelFormat.Diffusers]:
+            if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
                 model_is_quantized = False
             elif transformer_config.format in [ModelFormat.GGUFQuantized]:
                 model_is_quantized = True
             else:
                 raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")
 
-            # Apply LoRA models to the transformer.
+            # Load control adapter and prepare combined transformer if control is provided
+            control_context: list[torch.Tensor] | None = None
+            control_context_scale = 0.75
+            begin_step_percent = 0.0
+            end_step_percent = 1.0
+            cached_weights = None
+
+            if self.control is not None:
+                # Simplified approach: create the model and load weights directly
+                # 1. Get the base transformer and its config (stays on CPU)
+                # 2. Create ZImageControlTransformer2DModel
+                # 3. Load base weights directly (strict=False)
+                # 4. Load adapter control weights on top
+                # 5. Move to GPU
+
+                # Get the base transformer and its config (stays on CPU)
+                base_transformer = cast(ZImageTransformer2DModel, transformer_info.model)
+                base_config = base_transformer.config
+
+                # Load control adapter (stays on CPU)
+                control_model_info = context.models.load(self.control.control_model)
+                control_adapter = control_model_info.model
+                assert isinstance(control_adapter, ZImageControlAdapter)
+
+                # Create ZImageControlTransformer2DModel
+                control_transformer = ZImageControlTransformer2DModel(
+                    all_patch_size=base_config.all_patch_size,
+                    all_f_patch_size=base_config.all_f_patch_size,
+                    in_channels=base_config.in_channels,
+                    dim=base_config.dim,
+                    n_layers=base_config.n_layers,
+                    n_refiner_layers=base_config.n_refiner_layers,
+                    n_heads=base_config.n_heads,
+                    n_kv_heads=base_config.n_kv_heads,
+                    norm_eps=base_config.norm_eps,
+                    qk_norm=base_config.qk_norm,
+                    cap_feat_dim=base_config.cap_feat_dim,
+                    rope_theta=base_config.rope_theta,
+                    t_scale=base_config.t_scale,
+                    axes_dims=base_config.axes_dims,
+                    axes_lens=base_config.axes_lens,
+                )
+
+                # Load base transformer weights directly (strict=False handles missing control keys)
+                control_transformer.load_state_dict(base_transformer.state_dict(), strict=False, assign=True)
+
+                # Load control adapter weights on top (only control-specific keys)
+                # Filter to only control_ prefixed keys to avoid overwriting x_pad_token
+                adapter_control_weights = {
+                    k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")
+                }
+                control_transformer.load_state_dict(adapter_control_weights, strict=False, assign=True)
+
+                # Move to device
+                control_transformer = control_transformer.to(device=device, dtype=inference_dtype)
+                active_transformer = control_transformer
+
+                # Load and prepare the control image - it must be VAE-encoded!
+                if self.vae is None:
+                    raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
+
+                control_image = context.images.get_pil(self.control.image_name)
+
+                # Resize control image to match output dimensions
+                control_image = control_image.convert("RGB")
+                control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
+
+                # Convert to tensor format for VAE encoding
+                from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+
+                control_image_tensor = image_resized_to_grid_as_tensor(control_image)
+                if control_image_tensor.dim() == 3:
+                    control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
+
+                # Encode control image through VAE to get latents
+                vae_info = context.models.load(self.vae.vae)
+                control_latents = ZImageImageToLatentsInvocation.vae_encode(
+                    vae_info=vae_info,
+                    image_tensor=control_image_tensor,
+                )
+
+                # Move to inference device/dtype
+                control_latents = control_latents.to(device=device, dtype=inference_dtype)
+
+                # Add frame dimension: [B, C, H, W] -> [B, C, 1, H, W]
+                control_latents = control_latents.unsqueeze(2)
+                # Convert to list format expected by transformer
+                control_context = list(control_latents.unbind(dim=0))
+                control_context_scale = self.control.control_context_scale
+                begin_step_percent = self.control.begin_step_percent
+                end_step_percent = self.control.end_step_percent
+            else:
+                # No control - load transformer normally
+                (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+                active_transformer = transformer
+
+            # Apply LoRA models to the active transformer.
             # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            # cached_weights is None when using control (since we create a new model), otherwise it's from model_on_device
             exit_stack.enter_context(
                 LayerPatcher.apply_smart_model_patches(
-                    model=transformer,
+                    model=active_transformer,
                     patches=self._lora_iterator(context),
                     prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
                     dtype=inference_dtype,
@@ -326,28 +439,55 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
                 # Run transformer for positive prediction
                 # Z-Image transformer expects: x as list of [C, 1, H, W] tensors, t, cap_feats as list
                 # Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
-                latent_model_input = latents.to(transformer.dtype)
+                latent_model_input = latents.to(active_transformer.dtype)
                 latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
 
-                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
-                model_output = transformer(
-                    x=latent_model_input_list,
-                    t=timestep,
-                    cap_feats=[pos_prompt_embeds],
+                # Determine if control should be applied at this step
+                step_percent = step_idx / total_steps
+                apply_control = (
+                    control_context is not None
+                    and step_percent >= begin_step_percent
+                    and step_percent <= end_step_percent
                 )
+
+                # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
+                # If control is active, pass control_context to the control transformer
+                if apply_control:
+                    model_output = active_transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                        control_context=control_context,
+                        control_context_scale=control_context_scale,
+                    )
+                else:
+                    model_output = active_transformer(
+                        x=latent_model_input_list,
+                        t=timestep,
+                        cap_feats=[pos_prompt_embeds],
+                    )
                 model_out_list = model_output[0]  # Extract list of tensors from tuple
                 noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                 noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
                 noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation
 
                 # Apply CFG if enabled
                 if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                    model_output_uncond = transformer(
-                        x=latent_model_input_list,
-                        t=timestep,
-                        cap_feats=[neg_prompt_embeds],
-                    )
+                    if apply_control:
+                        model_output_uncond = active_transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                            control_context=control_context,
+                            control_context_scale=control_context_scale,
+                        )
+                    else:
+                        model_output_uncond = active_transformer(
+                            x=latent_model_input_list,
+                            t=timestep,
+                            cap_feats=[neg_prompt_embeds],
+                        )
                     model_out_list_uncond = model_output_uncond[0]  # Extract list of tensors from tuple
                    noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                     noise_pred_uncond = noise_pred_uncond.squeeze(2)
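
For illustration, here is a tiny standalone sketch of the step gating that apply_control implements in the hunk above. The values are hypothetical; only the comparison logic, which is inclusive on both bounds, mirrors the diff:

    # Hypothetical values; the check mirrors apply_control above (both bounds inclusive).
    total_steps = 8
    begin_step_percent, end_step_percent = 0.0, 0.5
    active_steps = [
        step_idx
        for step_idx in range(total_steps)
        if begin_step_percent <= step_idx / total_steps <= end_step_percent
    ]
    print(active_steps)  # [0, 1, 2, 3, 4] -- step 4 qualifies because 4/8 == 0.5 <= 0.5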