@@ -253,7 +253,10 @@ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
253
253
to_concat = []
254
254
for c in self .extra_concat_orig :
255
255
c = c .to (self .cond_hint .device )
256
- c = comfy .utils .common_upscale (c , self .cond_hint .shape [3 ], self .cond_hint .shape [2 ], self .upscale_algorithm , "center" )
256
+ c = comfy .utils .common_upscale (c , self .cond_hint .shape [- 1 ], self .cond_hint .shape [- 2 ], self .upscale_algorithm , "center" )
257
+ if c .ndim < self .cond_hint .ndim :
258
+ c = c .unsqueeze (2 )
259
+ c = comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [2 ], dim = 2 )
257
260
to_concat .append (comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [0 ]))
258
261
self .cond_hint = torch .cat ([self .cond_hint ] + to_concat , dim = 1 )
259
262
@@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
585
588
586
589
def load_controlnet_qwen_instantx (sd , model_options = {}):
587
590
model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (sd , model_options = model_options )
588
- control_model = comfy .ldm .qwen_image .controlnet .QwenImageControlNetModel (operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
591
+ control_latent_channels = sd .get ("controlnet_x_embedder.weight" ).shape [1 ]
592
+
593
+ extra_condition_channels = 0
594
+ concat_mask = False
595
+ if control_latent_channels == 68 : #inpaint controlnet
596
+ extra_condition_channels = control_latent_channels - 64
597
+ concat_mask = True
598
+ control_model = comfy .ldm .qwen_image .controlnet .QwenImageControlNetModel (extra_condition_channels = extra_condition_channels , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
589
599
control_model = controlnet_load_state_dict (control_model , sd )
590
600
latent_format = comfy .latent_formats .Wan21 ()
591
601
extra_conds = []
592
- control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
602
+ control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , concat_mask = concat_mask , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
593
603
return control
594
604
595
605
def convert_mistoline (sd ):
0 commit comments