Skip to content

Commit d75f65a

Browse files
committed
support sd1.5 controlnet
1 parent 056f81c commit d75f65a

File tree

9 files changed

+756
-30
lines changed

9 files changed

+756
-30
lines changed

diffsynth_engine/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ControlNetParams,
1111
)
1212
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
13+
from .models.sd import SDControlNet
1314
from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
1415
from .utils.video import load_video, save_video
1516
from .tools import (
@@ -25,6 +26,7 @@
2526
"FluxControlNet",
2627
"FluxIPAdapter",
2728
"FluxRedux",
29+
"SDControlNet",
2830
"SDXLImagePipeline",
2931
"SDImagePipeline",
3032
"WanVideoPipeline",
diffsynth_engine/models/sd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .sd_text_encoder import SDTextEncoder, config as sd_text_encoder_config
22
from .sd_unet import SDUNet, config as sd_unet_config
33
from .sd_vae import SDVAEDecoder, SDVAEEncoder
4+
from .sd_controlnet import SDControlNet
45

56
__all__ = [
67
"SDTextEncoder",
78
"SDUNet",
89
"SDVAEDecoder",
910
"SDVAEEncoder",
11+
"SDControlNet",
1012
"sd_text_encoder_config",
1113
"sd_unet_config",
1214
]

diffsynth_engine/models/sd/sd_controlnet.py

Lines changed: 597 additions & 0 deletions
Large diffs are not rendered by default.

diffsynth_engine/models/sd/sd_unet.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float16):
264264
self.conv_act = nn.SiLU()
265265
self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1, device=device, dtype=dtype)
266266

267-
def forward(self, x, timestep, context, **kwargs):
267+
def forward(self, x, timestep, context, controlnet_res_stack=None, **kwargs):
268268
# 1. time
269269
time_emb = self.time_embedding(timestep, dtype=x.dtype)
270270

@@ -273,10 +273,18 @@ def forward(self, x, timestep, context, **kwargs):
273273
text_emb = context
274274
res_stack = [hidden_states]
275275

276+
controlnet_insert_block_id = 30
277+
276278
# 3. blocks
277279
for i, block in enumerate(self.blocks):
280+
# 3.1 UNet
278281
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
279282

283+
# 3.2 Controlnet
284+
if i == controlnet_insert_block_id and controlnet_res_stack is not None:
285+
hidden_states += controlnet_res_stack.pop()
286+
res_stack = [res + controlnet_res for res, controlnet_res in zip(res_stack, controlnet_res_stack)]
287+
280288
# 4. output
281289
hidden_states = self.conv_norm_out(hidden_states)
282290
hidden_states = self.conv_act(hidden_states)
diffsynth_engine/pipelines/controlnet_helper.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
import torch.nn as nn
3+
from typing import Dict, List, Tuple, Union, Optional
4+
from PIL import Image
5+
from dataclasses import dataclass
6+
7+
ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
8+
9+
@dataclass
10+
class ControlNetParams:
11+
scale: float
12+
image: ImageType
13+
model: Optional[nn.Module] = None
14+
mask: Optional[ImageType] = None
15+
control_start: float = 0
16+
control_end: float = 1
17+
18+
def accumulate(result, new_item):
19+
if result is None:
20+
return new_item
21+
for i, item in enumerate(new_item):
22+
result[i] += item
23+
return result

diffsynth_engine/pipelines/flux_image.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from diffsynth_engine.models.basic.lora import LoRAContext
2424
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
25+
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
2526
from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
2627
from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
2728
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
@@ -415,17 +416,6 @@ def calculate_shift(
415416
return mu
416417

417418

418-
def accumulate(result, new_item):
419-
if result is None:
420-
return new_item
421-
for i, item in enumerate(new_item):
422-
result[i] += item
423-
return result
424-
425-
426-
ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
427-
428-
429419
class ControlType(Enum):
430420
normal = "normal"
431421
bfl_control = "bfl_control"
@@ -439,17 +429,6 @@ def get_in_channel(self):
439429
elif self == ControlType.bfl_fill:
440430
return 384
441431

442-
443-
@dataclass
444-
class ControlNetParams:
445-
scale: float
446-
image: ImageType
447-
model: Optional[nn.Module] = None
448-
mask: Optional[ImageType] = None
449-
control_start: float = 0
450-
control_end: float = 1
451-
452-
453432
@dataclass
454433
class FluxModelConfig:
455434
dit_path: str | os.PathLike

diffsynth_engine/pipelines/sd_image.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import numpy as np
55
from einops import repeat
66
from dataclasses import dataclass
7-
from typing import Callable, Dict, Optional
7+
from typing import Callable, Dict, Optional, List
88
from tqdm import tqdm
99
from PIL import Image, ImageOps
1010

1111
from diffsynth_engine.models.base import split_suffix
1212
from diffsynth_engine.models.basic.lora import LoRAContext
1313
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
1414
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
15+
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
1516
from diffsynth_engine.tokenizers import CLIPTokenizer
1617
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
1718
from diffsynth_engine.algorithm.sampler import EulerSampler
@@ -259,37 +260,100 @@ def encode_prompt(self, prompt, clip_skip):
259260
prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
260261
return prompt_emb
261262

263+
def preprocess_control_image(self, image: Image.Image, mode="RGB") -> torch.Tensor:
264+
image = image.convert(mode)
265+
image_array = np.array(image, dtype=np.float32)
266+
if len(image_array.shape) == 2:
267+
image_array = image_array[:, :, np.newaxis]
268+
image = torch.Tensor(image_array / 255).permute(2, 0, 1).unsqueeze(0)
269+
return image
270+
271+
def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
272+
results = []
273+
for param in controlnet_params:
274+
condition = self.preprocess_control_image(param.image).to(device=self.device, dtype=self.dtype)
275+
results.append(
276+
ControlNetParams(
277+
model=param.model,
278+
scale=param.scale,
279+
image=condition,
280+
)
281+
)
282+
return results
283+
284+
def predict_multicontrolnet(
285+
self,
286+
latents: torch.Tensor,
287+
timestep: torch.Tensor,
288+
prompt_emb: torch.Tensor,
289+
controlnet_params: List[ControlNetParams],
290+
current_step: int,
291+
total_step: int,
292+
):
293+
controlnet_res_stack = None
294+
if len(controlnet_params) > 0:
295+
self.load_models_to_device([])
296+
for param in controlnet_params:
297+
current_scale = param.scale
298+
if not (
299+
current_step >= param.control_start * total_step and current_step <= param.control_end * total_step
300+
):
301+
# if current_step is not in the control range
302+
# skip this controlnet
303+
continue
304+
if self.offload_mode is not None:
305+
empty_cache()
306+
param.model.to(self.device)
307+
controlnet_res = param.model(
308+
latents,
309+
timestep,
310+
prompt_emb,
311+
param.image
312+
)
313+
controlnet_res = [res * current_scale for res in controlnet_res]
314+
if self.offload_mode is not None:
315+
empty_cache()
316+
param.model.to("cpu")
317+
controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res)
318+
return controlnet_res_stack
319+
262320
def predict_noise_with_cfg(
263321
self,
264322
latents: torch.Tensor,
265323
timestep: torch.Tensor,
266324
positive_prompt_emb: torch.Tensor,
267325
negative_prompt_emb: torch.Tensor,
326+
controlnet_params: List[ControlNetParams],
327+
current_step: int,
328+
total_step: int,
268329
cfg_scale: float,
269330
batch_cfg: bool = True,
270331
):
271332
if cfg_scale <= 1.0:
272-
return self.predict_noise(latents, timestep, positive_prompt_emb)
333+
return self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step)
273334
if not batch_cfg:
274335
# cfg by predict noise one by one
275-
positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb)
276-
negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb)
336+
positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step)
337+
negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step)
277338
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
278339
return noise_pred
279340
else:
280341
# cfg by predict noise in one batch
281342
prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0)
282343
latents = torch.cat([latents, latents], dim=0)
283344
timestep = torch.cat([timestep, timestep], dim=0)
284-
positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb).chunk(2)
345+
positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb, controlnet_params, current_step, total_step).chunk(2)
285346
noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
286347
return noise_pred
287348

288-
def predict_noise(self, latents, timestep, prompt_emb):
349+
def predict_noise(self, latents, timestep, prompt_emb, controlnet_params, current_step, total_step):
350+
controlnet_res_stack = self.predict_multicontrolnet(latents, timestep, prompt_emb, controlnet_params, current_step, total_step)
351+
289352
noise_pred = self.unet(
290353
x=latents,
291354
timestep=timestep,
292355
context=prompt_emb,
356+
controlnet_res_stack=controlnet_res_stack,
293357
device=self.device,
294358
)
295359
return noise_pred
@@ -329,8 +393,12 @@ def __call__(
329393
width: int = 1024,
330394
num_inference_steps: int = 20,
331395
seed: int | None = None,
396+
controlnet_params: List[ControlNetParams] | ControlNetParams = [],
332397
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
333398
):
399+
if not isinstance(controlnet_params, list):
400+
controlnet_params = [controlnet_params]
401+
334402
if input_image is not None:
335403
width, height = input_image.size
336404
self.validate_image_size(height, width, minimum=64, multiple_of=8)
@@ -345,6 +413,9 @@ def __call__(
345413
# Initialize sampler
346414
self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas, mask=mask)
347415

416+
# ControlNet
417+
controlnet_params = self.prepare_controlnet_params(controlnet_params, h=height, w=width)
418+
348419
# Encode prompts
349420
self.load_models_to_device(["text_encoder"])
350421
positive_prompt_emb = self.encode_prompt(prompt, clip_skip=clip_skip)
@@ -361,6 +432,9 @@ def __call__(
361432
positive_prompt_emb=positive_prompt_emb,
362433
negative_prompt_emb=negative_prompt_emb,
363434
cfg_scale=cfg_scale,
435+
controlnet_params=controlnet_params,
436+
current_step=i,
437+
total_step=len(timesteps),
364438
batch_cfg=self.batch_cfg,
365439
)
366440
# Denoise

diffsynth_engine/utils/flag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
else:
2020
logger.info("Flash attention 2 is not available")
2121

22-
XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None
22+
XFORMERS_AVAILABLE = None # importlib.util.find_spec("xformers") is not None
2323
if XFORMERS_AVAILABLE:
2424
logger.info("XFormers is available")
2525
else:
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
3+
from tests.common.test_case import ImageTestCase
4+
from diffsynth_engine import SDImagePipeline, SDControlNet, ControlNetParams, fetch_model
5+
import torch
6+
7+
8+
class TestSDControlNet(ImageTestCase):
9+
@classmethod
10+
def setUpClass(cls):
11+
model_path = fetch_model(
12+
"muse/v1-5-pruned-emaonly", revision="20240118200020", path="v1-5-pruned-emaonly.safetensors"
13+
)
14+
cls.pipe = SDImagePipeline.from_pretrained(model_path)
15+
16+
def test_canny(self):
17+
canny_image = self.get_input_image("canny.png")
18+
controlnet = SDControlNet.from_pretrained(
19+
fetch_model("lllyasviel/sd-controlnet-canny", path="diffusion_pytorch_model.safetensors"),
20+
device="cuda:0",
21+
dtype=torch.float16,
22+
)
23+
output_image = self.pipe(
24+
prompt="A young girl stands gracefully at the edge of a serene beach, her long, flowing hair gently tousled by the sea breeze. She wears a soft, pastel-colored dress that complements the tranquil blues and greens of the coastal scenery. The golden hues of the setting sun cast a warm glow on her face, highlighting her serene expression. The background features a vast, azure ocean with gentle waves lapping at the shore, surrounded by distant cliffs and a clear, cloudless sky. The composition emphasizes the girl's serene presence amidst the natural beauty, with a balanced blend of warm and cool tones.",
25+
height=canny_image.height,
26+
width=canny_image.width,
27+
num_inference_steps=30,
28+
seed=42,
29+
controlnet_params=ControlNetParams(
30+
model=controlnet,
31+
scale=1.0,
32+
control_end=1.0,
33+
image=canny_image,
34+
),
35+
)
36+
# TODO: replace image
37+
self.assertImageEqualAndSaveFailed(output_image, "flux/flux_union_pro_canny.png", threshold=0.7)
38+
39+
40+
if __name__ == "__main__":
41+
unittest.main()

0 commit comments

Comments
 (0)