@@ -243,7 +243,13 @@ def load_model_patch(self, name):
243243 model = SigLIPMultiFeatProjModel (device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
244244 elif 'control_all_x_embedder.2-1.weight' in sd : # alipai z image fun controlnet
245245 sd = z_image_convert (sd )
246- model = comfy .ldm .lumina .controlnet .ZImage_Control (device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
246+ config = {}
247+ if 'control_layers.14.adaLN_modulation.0.weight' in sd :
248+ config ['n_control_layers' ] = 15
249+ config ['additional_in_dim' ] = 17
250+ config ['refiner_control' ] = True
251+ config ['broken' ] = True
252+ model = comfy .ldm .lumina .controlnet .ZImage_Control (device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast , ** config )
247253
248254 model .load_state_dict (sd )
249255 model = comfy .model_patcher .ModelPatcher (model , load_device = comfy .model_management .get_torch_device (), offload_device = comfy .model_management .unet_offload_device ())
@@ -297,56 +303,86 @@ def models(self):
297303 return [self .model_patch ]
298304
class ZImageControlPatch:
    """Per-block patch that adds Z-Image ControlNet output to the image tokens.

    The same instance is registered both as a noise-refiner patch and as a
    double-block patch; ``__call__`` tells the two apart through the
    ``block_type`` entry of the kwargs it receives.

    Args:
        model_patch: Patcher wrapping the control model. ``model_patch.model``
            must expose ``additional_in_dim``, ``n_control_layers``,
            ``forward_control_block`` and ``forward_noise_refiner_block``.
        vae: VAE used to encode the conditioning images.
        image: Control image. Presumably laid out as
            (batch, height, width, channels) -- dims 1 and 2 are used as the
            spatial size; TODO confirm against callers.
        strength: Multiplier applied to the control output before it is added.
        inpaint_image: Optional image to inpaint, same layout as ``image``.
            Only used when the control model has ``additional_in_dim > 0``.
        mask: Optional inpaint mask. Only used when the control model has
            ``additional_in_dim > 0``.
    """

    def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.inpaint_image = inpaint_image
        self.mask = mask
        self.strength = strength
        # (index of last processed layer, (output_to_add_or_None, control_state))
        self.temp_data = None

        if inpaint_image is not None and inpaint_image.shape != image.shape:
            # The two latents are concatenated channel-wise, so they must share
            # a spatial size. Defer encoding to __call__, which rescales both
            # images to the size implied by the latent being denoised.
            self.encoded_image = None
            self.encoded_image_size = None
        else:
            # Forward inpaint_image so the first encoding already contains it
            # instead of the neutral gray placeholder.
            self.encoded_image = self.encode_latent_cond(image, inpaint_image)
            self.encoded_image_size = (image.shape[1], image.shape[2])

    def encode_latent_cond(self, control_image, inpaint_image=None):
        """Encode the conditioning inputs into the latent fed to the control model.

        Args:
            control_image: Control image to VAE-encode.
            inpaint_image: Optional inpaint image. When ``None`` a constant
                0.5 image shaped like ``control_image`` is encoded instead.

        Returns:
            The Flux-format latent of ``control_image``. When the control model
            has ``additional_in_dim > 0`` this is concatenated along dim 1 with
            a single-channel mask and the inpaint-image latent.
        """
        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
        if self.model_patch.model.additional_in_dim <= 0:
            return latent_image

        if self.mask is None:
            # No mask supplied: use an all-zero single-channel mask.
            mask_ = torch.zeros_like(latent_image)[:, :1]
        else:
            # Collapse dim 1 of the mask to one channel and resize it to the
            # latent's width/height.
            mask_ = comfy.utils.common_upscale(
                self.mask.mean(dim=1, keepdim=True),
                latent_image.shape[-1],
                latent_image.shape[-2],
                "bilinear",
                "none",
            )

        if inpaint_image is None:
            inpaint_image = torch.ones_like(control_image) * 0.5

        inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
        return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)

    def __call__(self, kwargs):
        """Apply the control signal for the block described by ``kwargs``.

        Reads ``x``, ``img``, ``img_input``, ``txt``, ``pe``, ``vec``,
        ``block_index`` and ``block_type`` from ``kwargs``. ``img`` is modified
        in place, and both ``img`` and ``txt`` are removed from ``kwargs``
        before it is returned.

        Returns:
            The same ``kwargs`` dict, without the ``img`` and ``txt`` keys.
        """
        x = kwargs.get("x")
        img = kwargs.get("img")
        img_input = kwargs.get("img_input")
        txt = kwargs.get("txt")
        pe = kwargs.get("pe")
        vec = kwargs.get("vec")
        block_index = kwargs.get("block_index")
        block_type = kwargs.get("block_type", "")

        spacial_compression = self.vae.spacial_compression_encode()
        target_height = x.shape[-2] * spacial_compression
        target_width = x.shape[-1] * spacial_compression
        if self.encoded_image is None or self.encoded_image_size != (target_height, target_width):
            # Cached latent is missing or was built for another resolution:
            # rescale the source images and encode them again.
            image_scaled = comfy.utils.common_upscale(
                self.image.movedim(-1, 1), target_width, target_height, "area", "center"
            )
            inpaint_scaled = None
            if self.inpaint_image is not None:
                inpaint_scaled = comfy.utils.common_upscale(
                    self.inpaint_image.movedim(-1, 1), target_width, target_height, "area", "center"
                ).movedim(1, -1)
            # Remember which models are in use and reload them after encoding
            # (presumably the VAE encode can push them off the device).
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
            comfy.model_management.load_models_gpu(loaded_models)

        cnet_blocks = self.model_patch.model.n_control_layers
        # Number of main blocks per control layer; 30 is presumably the main
        # model's block count -- TODO confirm.
        div = round(30 / cnet_blocks)

        cnet_index = block_index // div
        cnet_index_float = block_index / div

        kwargs.pop("img")  # we do ops in place
        kwargs.pop("txt")

        if cnet_index_float > (cnet_blocks - 1):
            # Past the last control layer: drop the cached state and do nothing.
            self.temp_data = None
            return kwargs

        if self.temp_data is None or self.temp_data[0] > cnet_index:
            # No cached state, or the cache is ahead of this block (a new
            # forward pass has begun): run the control model from scratch.
            # The refiner path starts its layer counter at -3 rather than -1.
            start_index = -3 if block_type == "noise_refiner" else -1
            control_state = self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)
            self.temp_data = (start_index, (None, control_state))

        if block_type == "noise_refiner":
            next_layer = self.temp_data[0] + 1
            state = self.temp_data[1][1]
            self.temp_data = (
                next_layer,
                self.model_patch.model.forward_noise_refiner_block(
                    block_index, state, img_input[:, :state.shape[1]], None, pe, vec
                ),
            )
            if self.temp_data[1][0] is not None:
                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
        else:
            # Advance the control layers until they catch up with this block.
            while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
                next_layer = self.temp_data[0] + 1
                state = self.temp_data[1][1]
                self.temp_data = (
                    next_layer,
                    self.model_patch.model.forward_control_block(
                        next_layer, state, img_input[:, :state.shape[1]], None, pe, vec
                    ),
                )

            # Only blocks whose index is an exact multiple of `div` receive
            # the control output.
            if cnet_index_float == self.temp_data[0]:
                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
                if cnet_blocks == self.temp_data[0] + 1:
                    # Last control layer applied: reset for the next pass.
                    self.temp_data = None

        return kwargs
352388
@@ -386,7 +422,9 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=No
386422 mask = 1.0 - mask
387423
388424 if isinstance (model_patch .model , comfy .ldm .lumina .controlnet .ZImage_Control ):
389- model_patched .set_model_double_block_patch (ZImageControlPatch (model_patch , vae , image , strength ))
425+ patch = ZImageControlPatch (model_patch , vae , image , strength , mask = mask )
426+ model_patched .set_model_noise_refiner_patch (patch )
427+ model_patched .set_model_double_block_patch (patch )
390428 else :
391429 model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength , mask ))
392430 return (model_patched ,)
0 commit comments