Skip to content

Commit 8974ec1

Browse files
committed
add conv2d direct for controlnet
1 parent 9a349b2 commit 8974ec1

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

control.hpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ class ControlNetBlock : public GGMLBlock {
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
2929
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
30+
bool direct = false;
3031

3132
public:
3233
int model_channels = 320;
3334
int adm_in_channels = 2816; // only for VERSION_SDXL
3435

35-
ControlNetBlock(SDVersion version = VERSION_SD1)
36-
: version(version) {
36+
ControlNetBlock(SDVersion version = VERSION_SD1,
37+
bool direct = false)
38+
: version(version),
39+
direct(direct) {
3740
if (sd_version_is_sd2(version)) {
3841
context_dim = 1024;
3942
num_head_channels = 64;
@@ -65,7 +68,7 @@ class ControlNetBlock : public GGMLBlock {
6568
}
6669

6770
// input_blocks
68-
blocks["input_blocks.0.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, model_channels, {3, 3}, {1, 1}, {1, 1}));
71+
blocks["input_blocks.0.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, model_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
6972

7073
std::vector<int> input_block_chans;
7174
input_block_chans.push_back(model_channels);
@@ -86,26 +89,26 @@ class ControlNetBlock : public GGMLBlock {
8689
};
8790

8891
auto make_zero_conv = [&](int64_t channels) {
89-
return new Conv2d(channels, channels, {1, 1});
92+
return new Conv2d(channels, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, direct);
9093
};
9194

9295
blocks["zero_convs.0.0"] = std::shared_ptr<GGMLBlock>(make_zero_conv(model_channels));
9396

94-
blocks["input_hint_block.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(hint_channels, 16, {3, 3}, {1, 1}, {1, 1}));
97+
blocks["input_hint_block.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(hint_channels, 16, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
9598
// nn.SiLU()
96-
blocks["input_hint_block.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(16, 16, {3, 3}, {1, 1}, {1, 1}));
99+
blocks["input_hint_block.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(16, 16, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
97100
// nn.SiLU()
98-
blocks["input_hint_block.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(16, 32, {3, 3}, {2, 2}, {1, 1}));
101+
blocks["input_hint_block.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(16, 32, {3, 3}, {2, 2}, {1, 1}, {1, 1}, true, direct));
99102
// nn.SiLU()
100-
blocks["input_hint_block.6"] = std::shared_ptr<GGMLBlock>(new Conv2d(32, 32, {3, 3}, {1, 1}, {1, 1}));
103+
blocks["input_hint_block.6"] = std::shared_ptr<GGMLBlock>(new Conv2d(32, 32, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
101104
// nn.SiLU()
102-
blocks["input_hint_block.8"] = std::shared_ptr<GGMLBlock>(new Conv2d(32, 96, {3, 3}, {2, 2}, {1, 1}));
105+
blocks["input_hint_block.8"] = std::shared_ptr<GGMLBlock>(new Conv2d(32, 96, {3, 3}, {2, 2}, {1, 1}, {1, 1}, true, direct));
103106
// nn.SiLU()
104-
blocks["input_hint_block.10"] = std::shared_ptr<GGMLBlock>(new Conv2d(96, 96, {3, 3}, {1, 1}, {1, 1}));
107+
blocks["input_hint_block.10"] = std::shared_ptr<GGMLBlock>(new Conv2d(96, 96, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
105108
// nn.SiLU()
106-
blocks["input_hint_block.12"] = std::shared_ptr<GGMLBlock>(new Conv2d(96, 256, {3, 3}, {2, 2}, {1, 1}));
109+
blocks["input_hint_block.12"] = std::shared_ptr<GGMLBlock>(new Conv2d(96, 256, {3, 3}, {2, 2}, {1, 1}, {1, 1}, true, direct));
107110
// nn.SiLU()
108-
blocks["input_hint_block.14"] = std::shared_ptr<GGMLBlock>(new Conv2d(256, model_channels, {3, 3}, {1, 1}, {1, 1}));
111+
blocks["input_hint_block.14"] = std::shared_ptr<GGMLBlock>(new Conv2d(256, model_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
109112

110113
size_t len_mults = channel_mult.size();
111114
for (int i = 0; i < len_mults; i++) {
@@ -318,8 +321,9 @@ struct ControlNet : public GGMLRunner {
318321

319322
ControlNet(ggml_backend_t backend,
320323
const String2GGMLType& tensor_types = {},
321-
SDVersion version = VERSION_SD1)
322-
: GGMLRunner(backend), control_net(version) {
324+
SDVersion version = VERSION_SD1,
325+
bool direct = false)
326+
: GGMLRunner(backend), control_net(version, direct) {
323327
control_net.init(params_ctx, tensor_types, "");
324328
}
325329

@@ -455,4 +459,4 @@ struct ControlNet : public GGMLRunner {
455459
}
456460
};
457461

458-
#endif // __CONTROL_HPP__
462+
#endif // __CONTROL_HPP__

stable-diffusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ class StableDiffusionGGML {
423423
} else {
424424
controlnet_backend = backend;
425425
}
426-
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
426+
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version, sd_ctx_params->diffusion_conv_direct);
427427
}
428428

429429
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {

0 commit comments

Comments
 (0)