3535import comfy .cldm .mmdit
3636import comfy .ldm .hydit .controlnet
3737import comfy .ldm .flux .controlnet
38-
38+ import comfy . cldm . dit_embedder
3939
4040def broadcast_image_to (tensor , target_batch_size , batched_number ):
4141 current_batch_size = tensor .shape [0 ]
@@ -78,6 +78,7 @@ def __init__(self):
7878 self .concat_mask = False
7979 self .extra_concat_orig = []
8080 self .extra_concat = None
81+ self .preprocess_image = lambda a : a
8182
8283 def set_cond_hint (self , cond_hint , strength = 1.0 , timestep_percent_range = (0.0 , 1.0 ), vae = None , extra_concat = []):
8384 self .cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ def copy_to(self, c):
129130 c .strength_type = self .strength_type
130131 c .concat_mask = self .concat_mask
131132 c .extra_concat_orig = self .extra_concat_orig .copy ()
133+ c .preprocess_image = self .preprocess_image
132134
133135 def inference_memory_requirements (self , dtype ):
134136 if self .previous_controlnet is not None :
@@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None):
181183
182184
183185class ControlNet (ControlBase ):
184- def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False ):
186+ def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False , preprocess_image = lambda a : a ):
185187 super ().__init__ ()
186188 self .control_model = control_model
187189 self .load_device = load_device
@@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
196198 self .extra_conds += extra_conds
197199 self .strength_type = strength_type
198200 self .concat_mask = concat_mask
201+ self .preprocess_image = preprocess_image
199202
200203 def get_control (self , x_noisy , t , cond , batched_number ):
201204 control_prev = None
@@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
224227 if self .latent_format is not None :
225228 raise ValueError ("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it." )
226229 self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , self .upscale_algorithm , "center" )
230+ self .cond_hint = self .preprocess_image (self .cond_hint )
227231 if self .vae is not None :
228232 loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
229233 self .cond_hint = self .vae .encode (self .cond_hint .movedim (1 , - 1 ))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
427431 logging .debug ("unexpected controlnet keys: {}" .format (unexpected ))
428432 return control_model
429433
434+
430435def load_controlnet_mmdit (sd , model_options = {}):
431436 new_sd = comfy .model_detection .convert_diffusers_mmdit (sd , "" )
432437 model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd , model_options = model_options )
@@ -448,6 +453,82 @@ def load_controlnet_mmdit(sd, model_options={}):
448453 return control
449454
450455
456+ class ControlNetSD35 (ControlNet ):
457+ def pre_run (self , model , percent_to_timestep_function ):
458+ if self .control_model .double_y_emb :
459+ missing , unexpected = self .control_model .orig_y_embedder .load_state_dict (model .diffusion_model .y_embedder .state_dict (), strict = False )
460+ else :
461+ missing , unexpected = self .control_model .x_embedder .load_state_dict (model .diffusion_model .x_embedder .state_dict (), strict = False )
462+ super ().pre_run (model , percent_to_timestep_function )
463+
464+ def copy (self ):
465+ c = ControlNetSD35 (None , global_average_pooling = self .global_average_pooling , load_device = self .load_device , manual_cast_dtype = self .manual_cast_dtype )
466+ c .control_model = self .control_model
467+ c .control_model_wrapped = self .control_model_wrapped
468+ self .copy_to (c )
469+ return c
470+
471+ def load_controlnet_sd35 (sd , model_options = {}):
472+ control_type = - 1
473+ if "control_type" in sd :
474+ control_type = round (sd .pop ("control_type" ).item ())
475+
476+ # blur_cnet = control_type == 0
477+ canny_cnet = control_type == 1
478+ depth_cnet = control_type == 2
479+
480+ new_sd = {}
481+ for k in comfy .utils .MMDIT_MAP_BASIC :
482+ if k [1 ] in sd :
483+ new_sd [k [0 ]] = sd .pop (k [1 ])
484+ for k in sd :
485+ new_sd [k ] = sd [k ]
486+ sd = new_sd
487+
488+ y_emb_shape = sd ["y_embedder.mlp.0.weight" ].shape
489+ depth = y_emb_shape [0 ] // 64
490+ hidden_size = 64 * depth
491+ num_heads = depth
492+ head_dim = hidden_size // num_heads
493+ num_blocks = comfy .model_detection .count_blocks (new_sd , 'transformer_blocks.{}.' )
494+
495+ load_device = comfy .model_management .get_torch_device ()
496+ offload_device = comfy .model_management .unet_offload_device ()
497+ unet_dtype = comfy .model_management .unet_dtype (model_params = - 1 )
498+
499+ manual_cast_dtype = comfy .model_management .unet_manual_cast (unet_dtype , load_device )
500+
501+ operations = model_options .get ("custom_operations" , None )
502+ if operations is None :
503+ operations = comfy .ops .pick_operations (unet_dtype , manual_cast_dtype , disable_fast_fp8 = True )
504+
505+ control_model = comfy .cldm .dit_embedder .ControlNetEmbedder (img_size = None ,
506+ patch_size = 2 ,
507+ in_chans = 16 ,
508+ num_layers = num_blocks ,
509+ main_model_double = depth ,
510+ double_y_emb = y_emb_shape [0 ] == y_emb_shape [1 ],
511+ attention_head_dim = head_dim ,
512+ num_attention_heads = num_heads ,
513+ adm_in_channels = 2048 ,
514+ device = offload_device ,
515+ dtype = unet_dtype ,
516+ operations = operations )
517+
518+ control_model = controlnet_load_state_dict (control_model , sd )
519+
520+ latent_format = comfy .latent_formats .SD3 ()
521+ preprocess_image = lambda a : a
522+ if canny_cnet :
523+ preprocess_image = lambda a : (a * 255 * 0.5 + 0.5 )
524+ elif depth_cnet :
525+ preprocess_image = lambda a : 1.0 - a
526+
527+ control = ControlNetSD35 (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , preprocess_image = preprocess_image )
528+ return control
529+
530+
531+
451532def load_controlnet_hunyuandit (controlnet_data , model_options = {}):
452533 model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (controlnet_data , model_options = model_options )
453534
@@ -560,7 +641,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
560641 if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data :
561642 return load_controlnet_flux_xlabs_mistoline (controlnet_data , model_options = model_options )
562643 elif "pos_embed_input.proj.weight" in controlnet_data :
563- return load_controlnet_mmdit (controlnet_data , model_options = model_options ) #SD3 diffusers controlnet
644+ if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data :
645+ return load_controlnet_sd35 (controlnet_data , model_options = model_options ) #Stability sd3.5 format
646+ else :
647+ return load_controlnet_mmdit (controlnet_data , model_options = model_options ) #SD3 diffusers controlnet
564648 elif "controlnet_x_embedder.weight" in controlnet_data :
565649 return load_controlnet_flux_instantx (controlnet_data , model_options = model_options )
566650 elif "controlnet_blocks.0.linear.weight" in controlnet_data : #mistoline flux
0 commit comments