Commit a304372

support sdxl controlnet union
1 parent d75f65a commit a304372

12 files changed: +445 additions, −17 deletions

diffsynth_engine/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 )
 from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
 from .models.sd import SDControlNet
+from .models.sdxl import SDXLControlNetUnion
 from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
 from .utils.video import load_video, save_video
 from .tools import (
@@ -27,6 +28,7 @@
     "FluxIPAdapter",
     "FluxRedux",
     "SDControlNet",
+    "SDXLControlNetUnion",
     "SDXLImagePipeline",
     "SDImagePipeline",
     "WanVideoPipeline",

diffsynth_engine/models/sd/sd_controlnet.py

Lines changed: 1 addition & 2 deletions
@@ -1,9 +1,8 @@
-import json
 import torch
 import torch.nn as nn
 from typing import Dict, Optional
 
-from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter, split_suffix
+from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.models.basic.unet_helper import (

diffsynth_engine/models/sdxl/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,13 +1,15 @@
 from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2, config as sdxl_text_encoder_config
 from .sdxl_unet import SDXLUNet, config as sdxl_unet_config
 from .sdxl_vae import SDXLVAEDecoder, SDXLVAEEncoder
+from .sdxl_controlnet import SDXLControlNetUnion
 
 __all__ = [
     "SDXLTextEncoder",
     "SDXLTextEncoder2",
     "SDXLUNet",
     "SDXLVAEDecoder",
     "SDXLVAEEncoder",
+    "SDXLControlNetUnion",
     "sdxl_text_encoder_config",
     "sdxl_unet_config",
 ]

diffsynth_engine/models/sdxl/sdxl_controlnet.py

Lines changed: 299 additions & 0 deletions (new file)

@@ -0,0 +1,299 @@

import torch
import torch.nn as nn
from typing import Optional, Dict
from diffsynth_engine.models.basic.unet_helper import (
    ResnetBlock,
    AttentionBlock,
    PushBlock,
    DownSampler,
    PopBlock,
    UpSampler,
)
from diffsynth_engine.models.sd.sd_controlnet import ControlNetConditioningLayer
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings, TemporalTimesteps

from collections import OrderedDict


class QuickGELU(torch.nn.Module):
    # CLIP-style sigmoid approximation of GELU
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(torch.nn.Module):
    # CLIP-style residual attention block; used below as a tiny two-token
    # transformer that mixes condition statistics with latent statistics
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = torch.nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = torch.nn.LayerNorm(d_model)
        self.mlp = torch.nn.Sequential(OrderedDict([
            ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", torch.nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = torch.nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
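
QuickGELU is the sigmoid-based GELU approximation used by CLIP, and ResidualAttentionBlock is the matching CLIP-style transformer block; over the usual activation range the approximation tracks exact GELU to within about 0.02, which is why it can stand in for GELU here. A quick numerical comparison (illustrative, not part of the commit):

import torch

x = torch.linspace(-3, 3, 601)
quick = x * torch.sigmoid(1.702 * x)
exact = torch.nn.functional.gelu(x)
print((quick - exact).abs().max())  # roughly 0.02, worst near |x| = 2
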
class SDXLControlNetUnionStateDictConverter(StateDictConverter):
    def __init__(self):
        super().__init__()

    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # architecture of the down + mid path, in block order
        block_types = [
            "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
            "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
            "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock",
        ]

        # fixed renames for the conditioning embedder and control-type embedding
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
            "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
            "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
            "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
            "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
        }

        # rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
                pass
            elif name in controlnet_rename_dict:
                names = controlnet_rename_dict[name].split(".")
            elif names[0] == "controlnet_down_blocks":
                names[0] = "controlnet_blocks"
            elif names[0] == "controlnet_mid_block":
                names = ["controlnet_blocks", "9", names[-1]]
            elif names[0] == "time_embedding":
                names[1] = {"linear_1": "timestep_embedder.0", "linear_2": "timestep_embedder.2"}[names[1]]
            elif names[0] == "add_embedding":
                names[0] = "add_time_embedding"
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] == "control_add_embedding":
                names[0] = "control_type_embedding"
            elif names[0] == "transformer_layes":  # "layes" (sic) is the key spelling in the released checkpoint
                names[0] = "controlnet_transformer"
                names.pop(1)
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index + 3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index + 3:]
                if "to_out" in names:
                    names.pop(names.index("to_out") + 1)
            else:
                # unknown keys are reported and carried through under their original name
                print(name, state_dict[name].shape)
                # raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name not in rename_dict:
                continue
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            state_dict_[rename_dict[name]] = param
        return state_dict_

    # TODO: check civitai
    def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self._from_diffusers(state_dict)

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self._from_diffusers(state_dict)
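
A quick way to sanity-check the converter is to feed it a few representative diffusers keys; the expected renames below follow directly from the rules above (dummy shapes, illustrative only):

import torch

converter = SDXLControlNetUnionStateDictConverter()
sample = {
    "controlnet_mid_block.bias": torch.zeros(1280),
    "time_embedding.linear_1.weight": torch.zeros(1280, 320),
    "add_embedding.linear_2.bias": torch.zeros(1280),
}
print(sorted(converter.convert(sample).keys()))
# ['add_time_embedding.2.bias',
#  'controlnet_blocks.9.bias',
#  'time_embedding.timestep_embedder.0.weight']
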
class SDXLControlNetUnion(PreTrainedModel):
    converter = SDXLControlNetUnionStateDictConverter()

    def __init__(
        self,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)

        self.add_time_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
        self.add_time_embedding = torch.nn.Sequential(
            torch.nn.Linear(2816, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )
        self.control_type_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
        self.control_type_embedding = torch.nn.Sequential(
            torch.nn.Linear(256 * 8, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
        self.controlnet_transformer = ResidualAttentionBlock(320, 8)
        self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
        self.spatial_ch_projs = torch.nn.Linear(320, 320)

        self.blocks = torch.nn.ModuleList([
            # DownBlock2D
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
        ])

        # control types:
        # 0 -- openpose
        # 1 -- depth
        # 2 -- hed/pidi/scribble/ted
        # 3 -- canny/lineart/anime_lineart/mlsd
        # 4 -- normal
        # 5 -- segment
        # 6 -- tile
        # 7 -- repaint
        # (types 4 "normal" and 5 "segment" have no processor alias below yet)
        self.task_id = {
            "openpose": 0,
            "depth": 1,
            "softedge": 2,
            "canny": 3,
            "lineart": 3,
            "lineart_anime": 3,
            "tile": 6,
            "inpaint": 7,
        }

    def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
        controlnet_cond = self.controlnet_conv_in(conditioning)
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq = feat_seq + self.task_embedding[task_id]
        x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
        x = self.controlnet_transformer(x)

        alpha = self.spatial_ch_projs(x[:, 0]).unsqueeze(-1).unsqueeze(-1)
        controlnet_cond_fuser = controlnet_cond + alpha

        hidden_states = hidden_states + controlnet_cond_fuser
        return hidden_states
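
fuse_condition_to_input mean-pools the embedded condition map and the latents into two 320-dim tokens, adds the task embedding to the condition token, mixes the pair in the small attention block above, and broadcasts the projected condition token back over the condition map as a per-channel shift before adding everything to the latents. A minimal shape check (dummy tensors on CPU; assumes ControlNetConditioningLayer downsamples by 8x like the standard ControlNet conditioning embedder):

import torch

model = SDXLControlNetUnion(device="cpu", dtype=torch.float32)
latents = torch.randn(1, 320, 128, 128)   # conv_in output for a 128x128 latent
control = torch.randn(1, 3, 1024, 1024)   # preprocessed control image in pixel space
out = model.fuse_condition_to_input(latents, task_id=3, conditioning=control)
print(out.shape)  # torch.Size([1, 320, 128, 128])
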
    def forward(
        self,
        sample, timestep, encoder_hidden_states,
        conditioning, processor_name, add_time_id, add_text_embeds,
        tiled=False, tile_size=64, tile_stride=32,
        **kwargs,
    ):
        task_id = self.task_id[processor_name]

        # 1. time embedding
        t_emb = self.time_embedding(timestep, dtype=sample.dtype)
        time_embeds = self.add_time_proj(add_time_id)
        time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
        add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(sample.dtype)
        add_embeds = self.add_time_embedding(add_embeds)

        control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
        control_type[:, task_id] = 1
        control_embeds = self.control_type_proj(control_type.flatten())
        control_embeds = control_embeds.reshape((sample.shape[0], -1))
        control_embeds = control_embeds.to(sample.dtype)
        control_embeds = self.control_type_embedding(control_embeds)
        time_emb = t_emb + add_embeds + control_embeds

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample)
        hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
        text_emb = encoder_hidden_states
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        return controlnet_res_stack
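
forward returns ten feature maps: nine down-path residuals for the UNet's skip connections plus one mid-block output. Loading the published diffusers weights into the module would look roughly like this (a sketch; the checkpoint path is a placeholder, and strict=False tolerates any auxiliary keys the converter skips):

import torch
from safetensors.torch import load_file

raw = load_file("path/to/controlnet-union-sdxl/diffusion_pytorch_model.safetensors")
model = SDXLControlNetUnion(device="cpu", dtype=torch.float32)
converted = SDXLControlNetUnionStateDictConverter().convert(raw)
model.load_state_dict(converted, strict=False)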

diffsynth_engine/models/sdxl/sdxl_unet.py

Lines changed: 10 additions & 1 deletion
@@ -244,7 +244,7 @@ def __init__(
 
         self.is_kolors = is_kolors
 
-    def forward(self, x, timestep, context, y, **kwargs):
+    def forward(self, x, timestep, context, y, controlnet_res_stack=None, **kwargs):
         # 1. time embedding
         t_emb = self.time_embedding(timestep, dtype=x.dtype)
         ## add embedding
@@ -257,14 +257,23 @@ def forward(self, x, timestep, context, y, **kwargs):
         text_emb = context if self.text_intermediate_proj is None else self.text_intermediate_proj(context)
         res_stack = [hidden_states]
 
+        controlnet_insert_block_id = 22
+
         # 3. blocks
         for i, block in enumerate(self.blocks):
+            # 3.1 UNet
             hidden_states, time_emb, text_emb, res_stack = block(
                 hidden_states,
                 time_emb,
                 text_emb,
                 res_stack,
             )
+
+            # 3.2 Controlnet
+            if i == controlnet_insert_block_id and controlnet_res_stack is not None:
+                hidden_states += controlnet_res_stack.pop()
+                res_stack = [res + controlnet_res for res, controlnet_res in zip(res_stack, controlnet_res_stack)]
+
 
         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
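
Assuming the SDXL UNet's down and mid sections mirror the ControlNet block list above, index 22 is the last mid-block entry, so the injection adds the popped mid-block map to the mid output and folds the remaining nine maps into the skip connections. A denoising step would wire the two modules roughly like this (a sketch; tensor preparation, strength scaling, and CFG are omitted, and variable names are illustrative):

controlnet_res_stack = controlnet(
    sample=latents, timestep=t, encoder_hidden_states=text_emb,
    conditioning=control_image,   # 3-channel preprocessed control image
    processor_name="canny",       # selects task_id 3
    add_time_id=add_time_id, add_text_embeds=add_text_embeds,
)
noise_pred = unet(latents, t, context=text_emb, y=add_text_embeds,
                  controlnet_res_stack=controlnet_res_stack)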

diffsynth_engine/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 from .base import BasePipeline, LoRAStateDictConverter
-from .flux_image import FluxImagePipeline, FluxModelConfig, ControlNetParams
+from .controlnet_helper import ControlNetParams
+from .flux_image import FluxImagePipeline, FluxModelConfig
 from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
 from .sd_image import SDImagePipeline, SDModelConfig
 from .wan_video import WanVideoPipeline, WanModelConfig

diffsynth_engine/pipelines/controlnet_helper.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ class ControlNetParams:
     mask: Optional[ImageType] = None
     control_start: float = 0
     control_end: float = 1
+    processor_name: Optional[str] = None  # only used for sdxl controlnet union now
 
 def accumulate(result, new_item):
     if result is None:
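
Callers would then tag each control with its processor; a construction sketch (every field except processor_name and the three visible in the hunk above is an assumption about the rest of the dataclass):

params = ControlNetParams(
    model=controlnet,         # assumed field: an SDXLControlNetUnion instance
    image=control_image,      # assumed field: the preprocessed control image
    scale=0.8,                # assumed field: control strength
    processor_name="canny",   # new field added in this commit
)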
