44import numpy as np
55from einops import repeat
66from dataclasses import dataclass
7- from typing import Callable , Dict , Optional
7+ from typing import Callable , Dict , Optional , List
88from tqdm import tqdm
99from PIL import Image , ImageOps
1010
1111from diffsynth_engine .models .base import split_suffix
1212from diffsynth_engine .models .basic .lora import LoRAContext
1313from diffsynth_engine .models .sd import SDTextEncoder , SDVAEDecoder , SDVAEEncoder , SDUNet , sd_unet_config
1414from diffsynth_engine .pipelines import BasePipeline , LoRAStateDictConverter
15+ from diffsynth_engine .pipelines .controlnet_helper import ControlNetParams , accumulate
1516from diffsynth_engine .tokenizers import CLIPTokenizer
1617from diffsynth_engine .algorithm .noise_scheduler import ScaledLinearScheduler
1718from diffsynth_engine .algorithm .sampler import EulerSampler
@@ -259,37 +260,100 @@ def encode_prompt(self, prompt, clip_skip):
259260 prompt_emb = self .text_encoder (input_ids , clip_skip = clip_skip )
260261 return prompt_emb
261262
263+ def preprocess_control_image (self , image : Image .Image , mode = "RGB" ) -> torch .Tensor :
264+ image = image .convert (mode )
265+ image_array = np .array (image , dtype = np .float32 )
266+ if len (image_array .shape ) == 2 :
267+ image_array = image_array [:, :, np .newaxis ]
268+ image = torch .Tensor (image_array / 255 ).permute (2 , 0 , 1 ).unsqueeze (0 )
269+ return image
270+
271+ def prepare_controlnet_params (self , controlnet_params : List [ControlNetParams ], h , w ):
272+ results = []
273+ for param in controlnet_params :
274+ condition = self .preprocess_control_image (param .image ).to (device = self .device , dtype = self .dtype )
275+ results .append (
276+ ControlNetParams (
277+ model = param .model ,
278+ scale = param .scale ,
279+ image = condition ,
280+ )
281+ )
282+ return results
283+
284+ def predict_multicontrolnet (
285+ self ,
286+ latents : torch .Tensor ,
287+ timestep : torch .Tensor ,
288+ prompt_emb : torch .Tensor ,
289+ controlnet_params : List [ControlNetParams ],
290+ current_step : int ,
291+ total_step : int ,
292+ ):
293+ controlnet_res_stack = None
294+ if len (controlnet_params ) > 0 :
295+ self .load_models_to_device ([])
296+ for param in controlnet_params :
297+ current_scale = param .scale
298+ if not (
299+ current_step >= param .control_start * total_step and current_step <= param .control_end * total_step
300+ ):
301+ # if current_step is not in the control range
302+ # skip thie controlnet
303+ continue
304+ if self .offload_mode is not None :
305+ empty_cache ()
306+ param .model .to (self .device )
307+ controlnet_res = param .model (
308+ latents ,
309+ timestep ,
310+ prompt_emb ,
311+ param .image
312+ )
313+ controlnet_res = [res * current_scale for res in controlnet_res ]
314+ if self .offload_mode is not None :
315+ empty_cache ()
316+ param .model .to ("cpu" )
317+ controlnet_res_stack = accumulate (controlnet_res_stack , controlnet_res )
318+ return controlnet_res_stack
319+
262320 def predict_noise_with_cfg (
263321 self ,
264322 latents : torch .Tensor ,
265323 timestep : torch .Tensor ,
266324 positive_prompt_emb : torch .Tensor ,
267325 negative_prompt_emb : torch .Tensor ,
326+ controlnet_params : List [ControlNetParams ],
327+ current_step : int ,
328+ total_step : int ,
268329 cfg_scale : float ,
269330 batch_cfg : bool = True ,
270331 ):
271332 if cfg_scale <= 1.0 :
272- return self .predict_noise (latents , timestep , positive_prompt_emb )
333+ return self .predict_noise (latents , timestep , positive_prompt_emb , controlnet_params , current_step , total_step )
273334 if not batch_cfg :
274335 # cfg by predict noise one by one
275- positive_noise_pred = self .predict_noise (latents , timestep , positive_prompt_emb )
276- negative_noise_pred = self .predict_noise (latents , timestep , negative_prompt_emb )
336+ positive_noise_pred = self .predict_noise (latents , timestep , positive_prompt_emb , controlnet_params , current_step , total_step )
337+ negative_noise_pred = self .predict_noise (latents , timestep , negative_prompt_emb , controlnet_params , current_step , total_step )
277338 noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred )
278339 return noise_pred
279340 else :
280341 # cfg by predict noise in one batch
281342 prompt_emb = torch .cat ([positive_prompt_emb , negative_prompt_emb ], dim = 0 )
282343 latents = torch .cat ([latents , latents ], dim = 0 )
283344 timestep = torch .cat ([timestep , timestep ], dim = 0 )
284- positive_noise_pred , negative_noise_pred = self .predict_noise (latents , timestep , prompt_emb ).chunk (2 )
345+ positive_noise_pred , negative_noise_pred = self .predict_noise (latents , timestep , prompt_emb , controlnet_params , current_step , total_step ).chunk (2 )
285346 noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred )
286347 return noise_pred
287348
288- def predict_noise (self , latents , timestep , prompt_emb ):
349+ def predict_noise (self , latents , timestep , prompt_emb , controlnet_params , current_step , total_step ):
350+ controlnet_res_stack = self .predict_multicontrolnet (latents , timestep , prompt_emb , controlnet_params , current_step , total_step )
351+
289352 noise_pred = self .unet (
290353 x = latents ,
291354 timestep = timestep ,
292355 context = prompt_emb ,
356+ controlnet_res_stack = controlnet_res_stack ,
293357 device = self .device ,
294358 )
295359 return noise_pred
@@ -329,8 +393,12 @@ def __call__(
329393 width : int = 1024 ,
330394 num_inference_steps : int = 20 ,
331395 seed : int | None = None ,
396+ controlnet_params : List [ControlNetParams ] | ControlNetParams = [],
332397 progress_callback : Optional [Callable ] = None , # def progress_callback(current, total, status)
333398 ):
399+ if not isinstance (controlnet_params , list ):
400+ controlnet_params = [controlnet_params ]
401+
334402 if input_image is not None :
335403 width , height = input_image .size
336404 self .validate_image_size (height , width , minimum = 64 , multiple_of = 8 )
@@ -345,6 +413,9 @@ def __call__(
345413 # Initialize sampler
346414 self .sampler .initialize (init_latents = init_latents , timesteps = timesteps , sigmas = sigmas , mask = mask )
347415
416+ # ControlNet
417+ controlnet_params = self .prepare_controlnet_params (controlnet_params , h = height , w = width )
418+
348419 # Encode prompts
349420 self .load_models_to_device (["text_encoder" ])
350421 positive_prompt_emb = self .encode_prompt (prompt , clip_skip = clip_skip )
@@ -361,6 +432,9 @@ def __call__(
361432 positive_prompt_emb = positive_prompt_emb ,
362433 negative_prompt_emb = negative_prompt_emb ,
363434 cfg_scale = cfg_scale ,
435+ controlnet_params = controlnet_params ,
436+ current_step = i ,
437+ total_step = len (timesteps ),
364438 batch_cfg = self .batch_cfg ,
365439 )
366440 # Denoise
0 commit comments