@@ -35,6 +35,7 @@ def __init__(
3535 device = None , dtype = None , operations = None
3636 ):
3737 super ().__init__ ()
38+ self .additional_in_dim = additional_in_dim
3839 self .img_in = operations .Linear (in_dim + additional_in_dim , dim , device = device , dtype = dtype )
3940 self .controlnet_blocks = torch .nn .ModuleList (
4041 [
@@ -44,7 +45,7 @@ def __init__(
4445 )
4546
4647 def process_input_latent_image (self , latent_image ):
47- latent_image = comfy .latent_formats .Wan21 ().process_in (latent_image )
48+ latent_image [:, : 16 ] = comfy .latent_formats .Wan21 ().process_in (latent_image [:, : 16 ] )
4849 patch_size = 2
4950 hidden_states = comfy .ldm .common_dit .pad_to_patch_size (latent_image , (1 , patch_size , patch_size ))
5051 orig_shape = hidden_states .shape
@@ -73,19 +74,33 @@ def load_model_patch(self, name):
7374 sd = comfy .utils .load_torch_file (model_patch_path , safe_load = True )
7475 dtype = comfy .utils .weight_dtype (sd )
7576 # TODO: this node will work with more types of model patches
76- model = QwenImageBlockWiseControlNet (device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
77+ additional_in_dim = sd ["img_in.weight" ].shape [1 ] - 64
78+ model = QwenImageBlockWiseControlNet (additional_in_dim = additional_in_dim , device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
7779 model .load_state_dict (sd )
7880 model = comfy .model_patcher .ModelPatcher (model , load_device = comfy .model_management .get_torch_device (), offload_device = comfy .model_management .unet_offload_device ())
7981 return (model ,)
8082
8183
8284class DiffSynthCnetPatch :
83- def __init__ (self , model_patch , vae , image , strength ):
84- self .encoded_image = model_patch .model .process_input_latent_image (vae .encode (image ))
85+ def __init__ (self , model_patch , vae , image , strength , mask = None ):
8586 self .model_patch = model_patch
8687 self .vae = vae
8788 self .image = image
8889 self .strength = strength
90+ self .mask = mask
91+ self .encoded_image = model_patch .model .process_input_latent_image (self .encode_latent_cond (image ))
92+
93+ def encode_latent_cond (self , image ):
94+ latent_image = self .vae .encode (image )
95+ if self .model_patch .model .additional_in_dim > 0 :
96+ if self .mask is None :
97+ mask_ = torch .ones_like (latent_image )[:, :self .model_patch .model .additional_in_dim // 4 ]
98+ else :
99+ mask_ = comfy .utils .common_upscale (self .mask .mean (dim = 1 , keepdim = True ), latent_image .shape [- 1 ], latent_image .shape [- 2 ], "bilinear" , "none" )
100+
101+ return torch .cat ([latent_image , mask_ ], dim = 1 )
102+ else :
103+ return latent_image
89104
90105 def __call__ (self , kwargs ):
91106 x = kwargs .get ("x" )
@@ -95,7 +110,7 @@ def __call__(self, kwargs):
95110 spacial_compression = self .vae .spacial_compression_encode ()
96111 image_scaled = comfy .utils .common_upscale (self .image .movedim (- 1 , 1 ), x .shape [- 1 ] * spacial_compression , x .shape [- 2 ] * spacial_compression , "area" , "center" )
97112 loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
98- self .encoded_image = self .model_patch .model .process_input_latent_image (self .vae . encode (image_scaled .movedim (1 , - 1 )))
113+ self .encoded_image = self .model_patch .model .process_input_latent_image (self .encode_latent_cond (image_scaled .movedim (1 , - 1 )))
99114 comfy .model_management .load_models_gpu (loaded_models )
100115
101116 img = img + (self .model_patch .model .control_block (img , self .encoded_image .to (img .dtype ), block_index ) * self .strength )
@@ -118,17 +133,25 @@ def INPUT_TYPES(s):
118133 "vae" : ("VAE" ,),
119134 "image" : ("IMAGE" ,),
120135 "strength" : ("FLOAT" , {"default" : 1.0 , "min" : - 10.0 , "max" : 10.0 , "step" : 0.01 }),
121- }}
136+ },
137+ "optional" : {"mask" : ("MASK" ,)}}
122138 RETURN_TYPES = ("MODEL" ,)
123139 FUNCTION = "diffsynth_controlnet"
124140 EXPERIMENTAL = True
125141
126142 CATEGORY = "advanced/loaders/qwen"
127143
128- def diffsynth_controlnet (self , model , model_patch , vae , image , strength ):
144+ def diffsynth_controlnet (self , model , model_patch , vae , image , strength , mask = None ):
129145 model_patched = model .clone ()
130146 image = image [:, :, :, :3 ]
131- model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength ))
147+ if mask is not None :
148+ if mask .ndim == 3 :
149+ mask = mask .unsqueeze (1 )
150+ if mask .ndim == 4 :
151+ mask = mask .unsqueeze (2 )
152+ mask = 1.0 - mask
153+
154+ model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength , mask ))
132155 return (model_patched ,)
133156
134157
0 commit comments