1+ from enum import Enum
12import re
23import os
34import torch
2728from diffsynth_engine .utils .download import fetch_model
2829from diffsynth_engine .utils .platform import empty_cache
2930
31+ from einops import rearrange
32+
3033logger = logging .get_logger (__name__ )
3134
3235
@@ -244,11 +247,25 @@ def accumulate(result, new_item):
244247ImageType = Union [Image .Image , torch .Tensor , List [Image .Image ], List [torch .Tensor ]]
245248
246249
250+ class ControlType (Enum ):
251+ normal = "normal"
252+ bfl_control = "bfl_control"
253+ bfl_fill = "bfl_fill"
254+
255+ def get_in_channel (self ):
256+ if self == ControlType .normal :
257+ return 64
258+ elif self == ControlType .bfl_control :
259+ return 128
260+ elif self == ControlType .bfl_fill :
261+ return 384
262+
263+
247264@dataclass
248265class ControlNetParams :
249- model : nn .Module
250266 scale : float
251267 image : ImageType
268+ model : Optional [nn .Module ] = None
252269 mask : Optional [ImageType ] = None
253270 control_start : float = 0
254271 control_end : float = 1
@@ -287,6 +304,7 @@ def __init__(
287304 vae_tiled : bool = False ,
288305 vae_tile_size : int = 256 ,
289306 vae_tile_stride : int = 256 ,
307+ control_type : ControlType = ControlType .normal ,
290308 device : str = "cuda:0" ,
291309 dtype : torch .dtype = torch .bfloat16 ,
292310 ):
@@ -312,6 +330,7 @@ def __init__(
312330 self .batch_cfg = batch_cfg
313331 self .ip_adapter = None
314332 self .redux = None
333+ self .control_type = control_type
315334 self .model_names = [
316335 "text_encoder_1" ,
317336 "text_encoder_2" ,
@@ -324,6 +343,7 @@ def __init__(
324343 def from_pretrained (
325344 cls ,
326345 model_path_or_config : str | os .PathLike | FluxModelConfig ,
346+ control_type : ControlType = ControlType .normal ,
327347 device : str = "cuda:0" ,
328348 dtype : torch .dtype = torch .bfloat16 ,
329349 offload_mode : str | None = None ,
@@ -364,7 +384,11 @@ def from_pretrained(
364384 tokenizer_2 = T5TokenizerFast .from_pretrained (FLUX_TOKENIZER_2_CONF_PATH )
365385 with LoRAContext ():
366386 dit = FluxDiT .from_state_dict (
367- dit_state_dict , device = init_device , dtype = model_config .dit_dtype , attn_impl = model_config .dit_attn_impl
387+ dit_state_dict ,
388+ device = init_device ,
389+ dtype = model_config .dit_dtype ,
390+ in_channel = control_type .get_in_channel (),
391+ attn_impl = model_config .dit_attn_impl ,
368392 )
369393 if load_text_encoder :
370394 text_encoder_1 = FluxTextEncoder1 .from_state_dict (
@@ -386,6 +410,7 @@ def from_pretrained(
386410 vae_decoder = vae_decoder ,
387411 vae_encoder = vae_encoder ,
388412 load_text_encoder = load_text_encoder ,
413+ control_type = control_type ,
389414 device = device ,
390415 dtype = dtype ,
391416 )
@@ -535,6 +560,12 @@ def predict_noise(
535560 current_step : int ,
536561 total_step : int ,
537562 ):
563+ if self .control_type != ControlType .normal :
564+ controlnet_param = controlnet_params [0 ]
565+ latents = torch .cat ((latents , controlnet_param .image * controlnet_param .scale ), dim = 1 )
566+ latents = latents .to (self .dtype )
567+ controlnet_params = []
568+
538569 double_block_output , single_block_output = self .predict_multicontrolnet (
539570 latents = latents ,
540571 timestep = timestep ,
@@ -547,7 +578,9 @@ def predict_noise(
547578 current_step = current_step ,
548579 total_step = total_step ,
549580 )
581+
550582 self .load_models_to_device (["dit" ])
583+
551584 noise_pred = self .dit (
552585 hidden_states = latents ,
553586 timestep = timestep ,
@@ -600,16 +633,28 @@ def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, he
600633 image = self .preprocess_image (image ).to (device = self .device , dtype = self .dtype )
601634 latent = self .encode_image (image )
602635 else :
603- image = image .resize ((width , height ))
604- mask = mask .resize ((width , height ))
605- image = self .preprocess_image (image ).to (device = self .device , dtype = self .dtype )
606- mask = self .preprocess_mask (mask ).to (device = self .device , dtype = self .dtype )
607- masked_image = image .clone ()
608- masked_image [(mask > 0.5 ).repeat (1 , 3 , 1 , 1 )] = - 1
609- latent = self .encode_image (masked_image )
610- mask = torch .nn .functional .interpolate (mask , size = (latent .shape [2 ], latent .shape [3 ]))
611- mask = 1 - mask
612- latent = torch .cat ([latent , mask ], dim = 1 )
636+ if self .control_type == ControlType .normal :
637+ image = image .resize ((width , height ))
638+ mask = mask .resize ((width , height ))
639+ image = self .preprocess_image (image ).to (device = self .device , dtype = self .dtype )
640+ mask = self .preprocess_mask (mask ).to (device = self .device , dtype = self .dtype )
641+ masked_image = image .clone ()
642+ masked_image [(mask > 0.5 ).repeat (1 , 3 , 1 , 1 )] = - 1
643+ latent = self .encode_image (masked_image )
644+ mask = torch .nn .functional .interpolate (mask , size = (latent .shape [2 ], latent .shape [3 ]))
645+ mask = 1 - mask
646+ latent = torch .cat ([latent , mask ], dim = 1 )
647+ elif self .control_type == ControlType .bfl_fill :
648+ image = image .resize ((width , height ))
649+ mask = mask .resize ((width , height ))
650+ image = self .preprocess_image (image ).to (device = self .device , dtype = self .dtype )
651+ mask = self .preprocess_mask (mask ).to (device = self .device , dtype = self .dtype )
652+ image = image * (1 - mask )
653+ image = self .encode_image (image )
654+ mask = rearrange (mask , "b 1 (h ph) (w pw) -> b (ph pw) h w" , ph = 8 , pw = 8 )
655+ latent = torch .cat ((image , mask ), dim = 1 )
656+ else :
657+ raise ValueError (f"Unsupported mask latent prepare for controlnet type: { self .control_type } " )
613658 return latent
614659
615660 def prepare_controlnet_params (self , controlnet_params : List [ControlNetParams ], h , w ):
@@ -706,6 +751,9 @@ def __call__(
706751 controlnet_params : List [ControlNetParams ] | ControlNetParams = [],
707752 progress_callback : Optional [Callable ] = None , # def progress_callback(current, total, status)
708753 ):
754+ if self .control_type != ControlType .normal :
755+ assert controlnet_params and len (controlnet_params ) == 1 , "bfl_controlnet must have one controlnet"
756+
709757 if input_image is not None :
710758 width , height = input_image .size
711759 if not isinstance (controlnet_params , list ):
0 commit comments