@@ -333,57 +333,74 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
         begin_step_percent = 0.0
         end_step_percent = 1.0
         cached_weights = None
+        control_in_dim = 16  # Default; overridden from the adapter config when control is used
 
         if self.control is not None:
-            # Simplified approach: create the model and load weights directly
-            # 1. Get base transformer config (stays on CPU)
-            # 2. Create ZImageControlTransformer2DModel
-            # 3. Load base weights directly (strict=False)
-            # 4. Load adapter control weights on top
-            # 5. Move to GPU
-
-            # Get base transformer config (stays on CPU)
+            # Load the base transformer config (not moved to the GPU yet - just get the model reference)
             base_transformer = cast(ZImageTransformer2DModel, transformer_info.model)
             base_config = base_transformer.config
 
-            # Load control adapter (stays on CPU)
+            # Load the control adapter
            control_model_info = context.models.load(self.control.control_model)
            control_adapter = control_model_info.model
            assert isinstance(control_adapter, ZImageControlAdapter)
 
-            # Create ZImageControlTransformer2DModel
-            control_transformer = ZImageControlTransformer2DModel(
-                all_patch_size=base_config.all_patch_size,
-                all_f_patch_size=base_config.all_f_patch_size,
-                in_channels=base_config.in_channels,
-                dim=base_config.dim,
-                n_layers=base_config.n_layers,
-                n_refiner_layers=base_config.n_refiner_layers,
-                n_heads=base_config.n_heads,
-                n_kv_heads=base_config.n_kv_heads,
-                norm_eps=base_config.norm_eps,
-                qk_norm=base_config.qk_norm,
-                cap_feat_dim=base_config.cap_feat_dim,
-                rope_theta=base_config.rope_theta,
-                t_scale=base_config.t_scale,
-                axes_dims=base_config.axes_dims,
-                axes_lens=base_config.axes_lens,
+            # Get control_in_dim from the adapter config (16 for V1, 33 for V2.0)
+            adapter_config = control_adapter.config
+            control_in_dim = adapter_config.get("control_in_dim", 16)
+            num_control_blocks = adapter_config.get("num_control_blocks", 6)
+            n_refiner_layers = adapter_config.get("n_refiner_layers", 2)
+
+            # Calculate control_layers_places based on num_control_blocks
+            control_layers_places = [i * 2 for i in range(num_control_blocks)]
+
+            # Log the control configuration for debugging
+            version = "V2.0" if control_in_dim > 16 else "V1"
+            context.util.signal_progress(
+                f"Using Z-Image ControlNet {version}: control_in_dim={control_in_dim}, "
+                f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
             )
 
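For reference, the placement arithmetic introduced above is easy to check in isolation. A minimal sketch, assuming a V2.0-style adapter config (the dict literal below is illustrative, not a real checkpoint config):

# Illustrative config values - a real adapter config may carry more fields.
adapter_config = {"control_in_dim": 33, "num_control_blocks": 6}

control_in_dim = adapter_config.get("control_in_dim", 16)         # 16 = V1, 33 = V2.0
num_control_blocks = adapter_config.get("num_control_blocks", 6)

# Control blocks attach to every second transformer layer, per the i * 2 rule above.
control_layers_places = [i * 2 for i in range(num_control_blocks)]
print(control_layers_places)                    # [0, 2, 4, 6, 8, 10]
print("V2.0" if control_in_dim > 16 else "V1")  # V2.0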
-            # Load base transformer weights directly (strict=False handles missing control keys)
+            # Create the control transformer structure with empty (meta-device) weights
+            import accelerate
+
+            with accelerate.init_empty_weights():
+                control_transformer = ZImageControlTransformer2DModel(
+                    control_layers_places=control_layers_places,
+                    control_in_dim=control_in_dim,
+                    all_patch_size=base_config.all_patch_size,
+                    all_f_patch_size=base_config.all_f_patch_size,
+                    in_channels=base_config.in_channels,
+                    dim=base_config.dim,
+                    n_layers=base_config.n_layers,
+                    n_refiner_layers=n_refiner_layers,
+                    n_heads=base_config.n_heads,
+                    n_kv_heads=base_config.n_kv_heads,
+                    norm_eps=base_config.norm_eps,
+                    qk_norm=base_config.qk_norm,
+                    cap_feat_dim=base_config.cap_feat_dim,
+                    rope_theta=base_config.rope_theta,
+                    t_scale=base_config.t_scale,
+                    axes_dims=base_config.axes_dims,
+                    axes_lens=base_config.axes_lens,
+                )
+
+            # Load base weights with assign=True (assigns the tensors directly instead of
+            # copying data, which is required because the model was created on the meta device)
             control_transformer.load_state_dict(base_transformer.state_dict(), strict=False, assign=True)
 
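The init_empty_weights / assign=True pairing above is what keeps this from materializing a second full transformer in RAM: the model skeleton is created on PyTorch's meta device, then load_state_dict(assign=True) swaps the real tensors in without copying. A self-contained sketch of the pattern with a toy module (the class below is a stand-in, not the real Z-Image model):

import accelerate
import torch
from torch import nn

class Toy(nn.Module):
    """Stand-in for a large transformer with extra control-only layers."""
    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Linear(4, 4)     # present in the base checkpoint
        self.control = nn.Linear(4, 4)  # control-only, missing from the base checkpoint

with accelerate.init_empty_weights():
    model = Toy()  # all parameters on the meta device: no memory allocated yet

base_ckpt = {f"base.{k}": v for k, v in nn.Linear(4, 4).state_dict().items()}
# strict=False tolerates the missing control.* keys; assign=True replaces the
# meta tensors with the checkpoint tensors directly instead of copying into them.
model.load_state_dict(base_ckpt, strict=False, assign=True)
print(model.base.weight.device)     # cpu - loaded
print(model.control.weight.device)  # meta - still waiting for adapter weights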
-            # Load control adapter weights on top (only control-specific keys)
-            # Filter to only control_ prefixed keys to avoid overwriting x_pad_token
-            adapter_control_weights = {
-                k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")
-            }
-            control_transformer.load_state_dict(adapter_control_weights, strict=False, assign=True)
+            # Load the control adapter weights on top, keeping only control_-prefixed keys
+            # so the base model's x_pad_token is not overwritten
+            adapter_weights = {k: v for k, v in control_adapter.state_dict().items() if k.startswith("control_")}
+            control_transformer.load_state_dict(adapter_weights, strict=False, assign=True)
 
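Since both loads run with strict=False, a bad key filter fails silently rather than raising. A small sketch of how the prefix filter protects shared keys, and how the load_state_dict return value can be inspected (the key names here are hypothetical stand-ins):

import torch
from torch import nn

# Stand-in module holding one shared and one control-only parameter.
model = nn.ParameterDict({
    "x_pad_token": nn.Parameter(torch.zeros(8)),
    "control_proj": nn.Parameter(torch.zeros(8)),
})
adapter_ckpt = {"x_pad_token": torch.ones(8), "control_proj": torch.ones(8)}

# Keep only control_* keys so the base model's x_pad_token is not overwritten.
adapter_weights = {k: v for k, v in adapter_ckpt.items() if k.startswith("control_")}
result = model.load_state_dict(adapter_weights, strict=False, assign=True)
print(result.missing_keys)         # ['x_pad_token'] - deliberately left untouched
print(model["x_pad_token"].sum())  # tensor(0.) - base value preserved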
-            # Move to device
+            # Move the combined model to the target device
             control_transformer = control_transformer.to(device=device, dtype=inference_dtype)
             active_transformer = control_transformer
 
+            # TODO: confirm whether the adapter can be freed here to reclaim memory
+            # del control_adapter
+            # if torch.cuda.is_available():
+            #     torch.cuda.empty_cache()
+
             # Load and prepare the control image - it must be VAE-encoded!
             if self.vae is None:
                 raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
@@ -413,8 +430,23 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Add frame dimension: [B, C, H, W] -> [B, C, 1, H, W]
             control_latents = control_latents.unsqueeze(2)
-            # Convert to list format expected by transformer
+
+            # Prepare control_context based on control_in_dim:
+            #   V1:   16 channels (just the control latents)
+            #   V2.0: 33 channels (control latents + zero padding)
+            # Following the diffusers approach: simple zero-padding up to control_in_dim
+            b, c, f, h, w = control_latents.shape
+            if c < control_in_dim:
+                # Pad with zeros to match control_in_dim
+                padding_channels = control_in_dim - c
+                zero_padding = torch.zeros(
+                    (b, padding_channels, f, h, w),
+                    device=device,
+                    dtype=inference_dtype,
+                )
+                control_latents = torch.cat([control_latents, zero_padding], dim=1)
+
+            # Convert to the list format expected by the transformer
             control_context = list(control_latents.unbind(dim=0))
+
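The padding step above only manipulates channel counts, so it can be verified with a pure shape check. A minimal sketch with arbitrary latent dimensions:

import torch

control_in_dim = 33  # V2.0 adapter; a V1 adapter (16) would skip the padding branch
control_latents = torch.randn(1, 16, 1, 64, 64)  # [B, C, F, H, W] from the VAE encode

b, c, f, h, w = control_latents.shape
if c < control_in_dim:
    zero_padding = torch.zeros(
        (b, control_in_dim - c, f, h, w),
        device=control_latents.device,
        dtype=control_latents.dtype,
    )
    control_latents = torch.cat([control_latents, zero_padding], dim=1)

print(control_latents.shape)  # torch.Size([1, 33, 1, 64, 64])
control_context = list(control_latents.unbind(dim=0))  # one [33, 1, 64, 64] tensor per batch item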
             control_context_scale = self.control.control_context_scale
             begin_step_percent = self.control.begin_step_percent
             end_step_percent = self.control.end_step_percent
@@ -425,7 +457,8 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
         # Apply LoRA models to the active transformer.
         # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
-        # cached_weights is None when using control (since we create a new model), otherwise it's from model_on_device
+        # cached_weights is None when using control (since we create a new combined model),
+        # otherwise it comes from the model_on_device() context.
         exit_stack.enter_context(
             LayerPatcher.apply_smart_model_patches(
                 model=active_transformer,
@@ -457,15 +490,16 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Determine if control should be applied at this step
             step_percent = step_idx / total_steps
+            use_control = self.control is not None
             apply_control = (
-                control_context is not None
+                use_control
                 and step_percent >= begin_step_percent
                 and step_percent <= end_step_percent
             )
 
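Note that the gate above is inclusive at both ends and divides the zero-based step index by the total step count. A quick sketch of which steps fire for a given begin/end window:

total_steps = 10
begin_step_percent, end_step_percent = 0.3, 0.7

active_steps = [
    step_idx
    for step_idx in range(total_steps)
    if begin_step_percent <= step_idx / total_steps <= end_step_percent
]
print(active_steps)  # [3, 4, 5, 6, 7]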
             # Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
             # If control is active, pass control_context to the control transformer
-            if apply_control:
+            if apply_control and control_context is not None:
                 model_output = active_transformer(
                     x=latent_model_input_list,
                     t=timestep,
@@ -486,7 +520,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
 
             # Apply CFG if enabled
             if do_classifier_free_guidance and neg_prompt_embeds is not None:
-                if apply_control:
+                if apply_control and control_context is not None:
                     model_output_uncond = active_transformer(
                         x=latent_model_input_list,
                         t=timestep,