
Commit da913d6

fix controlnet device
1 parent a304372 commit da913d6

3 files changed (+79, -80 lines)

diffsynth_engine/models/sd/sd_controlnet.py

Lines changed: 40 additions & 40 deletions
@@ -15,17 +15,17 @@
 )
 
 class ControlNetConditioningLayer(nn.Module):
-    def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
+    def __init__(self, channels = (3, 16, 32, 96, 256, 320), device = "cuda:0", dtype=torch.float16):
         super().__init__()
         self.blocks = torch.nn.ModuleList([])
-        self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
+        self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype))
         self.blocks.append(torch.nn.SiLU())
         for i in range(1, len(channels) - 2):
-            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
+            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype))
             self.blocks.append(torch.nn.SiLU())
-            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
+            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype))
             self.blocks.append(torch.nn.SiLU())
-        self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
+        self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype))
 
     def forward(self, conditioning):
         for block in self.blocks:
@@ -496,64 +496,64 @@ def __init__(
     ):
         super().__init__()
         self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype)
-        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
+        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
 
-        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
+        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype)
 
         self.blocks = torch.nn.ModuleList([
             # CrossAttnDownBlock2D
-            ResnetBlock(320, 320, 1280),
-            AttentionBlock(8, 40, 320, 1, 768),
+            ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(320, 320, 1280),
-            AttentionBlock(8, 40, 320, 1, 768),
+            ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            DownSampler(320),
+            DownSampler(320, device=device, dtype=dtype),
             PushBlock(),
             # CrossAttnDownBlock2D
-            ResnetBlock(320, 640, 1280),
-            AttentionBlock(8, 80, 640, 1, 768),
+            ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(640, 640, 1280),
-            AttentionBlock(8, 80, 640, 1, 768),
+            ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            DownSampler(640),
+            DownSampler(640, device=device, dtype=dtype),
             PushBlock(),
             # CrossAttnDownBlock2D
-            ResnetBlock(640, 1280, 1280),
-            AttentionBlock(8, 160, 1280, 1, 768),
+            ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(1280, 1280, 1280),
-            AttentionBlock(8, 160, 1280, 1, 768),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
             PushBlock(),
-            DownSampler(1280),
+            DownSampler(1280, device=device, dtype=dtype),
             PushBlock(),
             # DownBlock2D
-            ResnetBlock(1280, 1280, 1280),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(1280, 1280, 1280),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
             PushBlock(),
             # UNetMidBlock2DCrossAttn
-            ResnetBlock(1280, 1280, 1280),
-            AttentionBlock(8, 160, 1280, 1, 768),
-            ResnetBlock(1280, 1280, 1280),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
             PushBlock()
         ])
 
         self.controlnet_blocks = torch.nn.ModuleList([
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype),
         ])
 
     def forward(
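
The change in this file is an instance of PyTorch's factory-kwarg idiom: passing device= and dtype= into each submodule constructor allocates the parameters on the target device, in the target precision, at construction time, instead of materializing them on the default device (CPU, float32) and relying on a later .to() sweep. A minimal, runnable sketch of the idiom; SmallConditioningLayer is a hypothetical stand-in for ControlNetConditioningLayer, not code from this repo:

import torch
import torch.nn as nn

class SmallConditioningLayer(nn.Module):
    # Hypothetical two-stage analogue of ControlNetConditioningLayer,
    # shown only to illustrate threading the factory kwargs.
    def __init__(self, channels=(3, 16, 320), device="cpu", dtype=torch.float32):
        super().__init__()
        self.blocks = nn.ModuleList([
            # With device=/dtype=, the weights are created on `device` in
            # `dtype`; no post-hoc .to() copy is needed.
            nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1,
                      device=device, dtype=dtype),
            nn.SiLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1,
                      device=device, dtype=dtype),
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

layer = SmallConditioningLayer()
x = torch.randn(1, 3, 64, 64)
print(layer(x).shape)  # torch.Size([1, 320, 64, 64])

Without the kwargs, a ControlNet constructed with device="cuda:0" would still hold these conv weights on CPU, and the first forward pass against a CUDA tensor would raise a device-mismatch RuntimeError, which is presumably the failure this commit fixes.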

diffsynth_engine/models/sdxl/sdxl_controlnet.py

Lines changed: 39 additions & 39 deletions
@@ -22,17 +22,17 @@ def forward(self, x: torch.Tensor):
 
 class ResidualAttentionBlock(torch.nn.Module):
 
-    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, device="cuda:0", dtype=torch.float16):
         super().__init__()
 
-        self.attn = torch.nn.MultiheadAttention(d_model, n_head)
-        self.ln_1 = torch.nn.LayerNorm(d_model)
+        self.attn = torch.nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype)
+        self.ln_1 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
         self.mlp = torch.nn.Sequential(OrderedDict([
-            ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
+            ("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
             ("gelu", QuickGELU()),
-            ("c_proj", torch.nn.Linear(d_model * 4, d_model))
+            ("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype))
         ]))
-        self.ln_2 = torch.nn.LayerNorm(d_model)
+        self.ln_2 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype)
         self.attn_mask = attn_mask
 
     def attention(self, x: torch.Tensor):
@@ -162,65 +162,65 @@ def __init__(self,
 
         self.add_time_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
         self.add_time_embedding = torch.nn.Sequential(
-            torch.nn.Linear(2816, 1280),
+            torch.nn.Linear(2816, 1280, device=device, dtype=dtype),
             torch.nn.SiLU(),
-            torch.nn.Linear(1280, 1280)
+            torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
         )
         self.control_type_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype)
         self.control_type_embedding = torch.nn.Sequential(
-            torch.nn.Linear(256 * 8, 1280),
+            torch.nn.Linear(256 * 8, 1280, device=device, dtype=dtype),
             torch.nn.SiLU(),
-            torch.nn.Linear(1280, 1280)
+            torch.nn.Linear(1280, 1280, device=device, dtype=dtype)
         )
-        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
+        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype)
 
-        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
-        self.controlnet_transformer = ResidualAttentionBlock(320, 8)
+        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype)
+        self.controlnet_transformer = ResidualAttentionBlock(320, 8, device=device, dtype=dtype)
         self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
-        self.spatial_ch_projs = torch.nn.Linear(320, 320)
+        self.spatial_ch_projs = torch.nn.Linear(320, 320, device=device, dtype=dtype)
 
         self.blocks = torch.nn.ModuleList([
             # DownBlock2D
-            ResnetBlock(320, 320, 1280),
+            ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(320, 320, 1280),
+            ResnetBlock(320, 320, 1280, device=device, dtype=dtype),
             PushBlock(),
-            DownSampler(320),
+            DownSampler(320, device=device, dtype=dtype),
             PushBlock(),
             # CrossAttnDownBlock2D
-            ResnetBlock(320, 640, 1280),
-            AttentionBlock(10, 64, 640, 2, 2048),
+            ResnetBlock(320, 640, 1280, device=device, dtype=dtype),
+            AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(640, 640, 1280),
-            AttentionBlock(10, 64, 640, 2, 2048),
+            ResnetBlock(640, 640, 1280, device=device, dtype=dtype),
+            AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype),
             PushBlock(),
-            DownSampler(640),
+            DownSampler(640, device=device, dtype=dtype),
             PushBlock(),
             # CrossAttnDownBlock2D
-            ResnetBlock(640, 1280, 1280),
-            AttentionBlock(20, 64, 1280, 10, 2048),
+            ResnetBlock(640, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
             PushBlock(),
-            ResnetBlock(1280, 1280, 1280),
-            AttentionBlock(20, 64, 1280, 10, 2048),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
             PushBlock(),
             # UNetMidBlock2DCrossAttn
-            ResnetBlock(1280, 1280, 1280),
-            AttentionBlock(20, 64, 1280, 10, 2048),
-            ResnetBlock(1280, 1280, 1280),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
+            AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype),
+            ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype),
             PushBlock()
         ])
 
         self.controlnet_blocks = torch.nn.ModuleList([
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
-            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
-            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
-            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
+            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype),
         ])
 
         # 0 -- openpose
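
The ResidualAttentionBlock hunk above applies the same idiom to a composite module: every leaf (MultiheadAttention, LayerNorm, and both Linear layers in the MLP) receives the kwargs, because a single missed leaf is enough to leave some parameters behind on the default device. A self-contained sketch under the same assumption; TinyResidualBlock is hypothetical, trimmed from the real class, and omits attn_mask and QuickGELU:

import torch
import torch.nn as nn
from collections import OrderedDict

class TinyResidualBlock(nn.Module):
    # Hypothetical, trimmed analogue of the patched ResidualAttentionBlock.
    def __init__(self, d_model, n_head, device="cpu", dtype=torch.float32):
        super().__init__()
        # Every leaf module gets device=/dtype=; forgetting even one would
        # leave that leaf's parameters on the default device.
        self.attn = nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype)
        self.ln_1 = nn.LayerNorm(d_model, device=device, dtype=dtype)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)),
            ("gelu", nn.GELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)),
        ]))
        self.ln_2 = nn.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, x):
        h = self.ln_1(x)
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + a
        return x + self.mlp(self.ln_2(x))

block = TinyResidualBlock(320, 8)
# All parameters were created on the requested device and dtype:
assert all(p.device.type == "cpu" and p.dtype == torch.float32
           for p in block.parameters())

A quick post-construction check like the assert above (or inspecting next(model.parameters()).device) is a cheap way to verify that a constructor threads the kwargs all the way down.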

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 0 additions & 1 deletion
@@ -452,7 +452,6 @@ def __call__(
 
         # ControlNet
         controlnet_params = self.prepare_controlnet_params(controlnet_params, h=height, w=width)
-
         # Encode prompts
         self.load_models_to_device(["text_encoder", "text_encoder_2"])
         positive_prompt_emb, positive_add_text_embeds = self.encode_prompt(prompt, clip_skip=clip_skip)
