@@ -27,13 +27,16 @@ class ControlNetBlock : public GGMLBlock {
2727 int num_heads = 8 ;
2828 int num_head_channels = -1 ; // channels // num_heads
2929 int context_dim = 768 ; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
30+ bool direct = false ;
3031
3132public:
3233 int model_channels = 320 ;
3334 int adm_in_channels = 2816 ; // only for VERSION_SDXL
3435
35- ControlNetBlock (SDVersion version = VERSION_SD1)
36- : version(version) {
36+ ControlNetBlock (SDVersion version = VERSION_SD1,
37+ bool direct = false )
38+ : version(version),
39+ direct (direct) {
3740 if (sd_version_is_sd2 (version)) {
3841 context_dim = 1024 ;
3942 num_head_channels = 64 ;
@@ -65,7 +68,7 @@ class ControlNetBlock : public GGMLBlock {
6568 }
6669
6770 // input_blocks
68- blocks[" input_blocks.0.0" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, model_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
71+ blocks[" input_blocks.0.0" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, model_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
6972
7073 std::vector<int > input_block_chans;
7174 input_block_chans.push_back (model_channels);
@@ -86,26 +89,26 @@ class ControlNetBlock : public GGMLBlock {
8689 };
8790
8891 auto make_zero_conv = [&](int64_t channels) {
89- return new Conv2d (channels, channels, {1 , 1 });
92+ return new Conv2d (channels, channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , direct );
9093 };
9194
9295 blocks[" zero_convs.0.0" ] = std::shared_ptr<GGMLBlock>(make_zero_conv (model_channels));
9396
94- blocks[" input_hint_block.0" ] = std::shared_ptr<GGMLBlock>(new Conv2d (hint_channels, 16 , {3 , 3 }, {1 , 1 }, {1 , 1 }));
97+ blocks[" input_hint_block.0" ] = std::shared_ptr<GGMLBlock>(new Conv2d (hint_channels, 16 , {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
9598 // nn.SiLU()
96- blocks[" input_hint_block.2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (16 , 16 , {3 , 3 }, {1 , 1 }, {1 , 1 }));
99+ blocks[" input_hint_block.2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (16 , 16 , {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
97100 // nn.SiLU()
98- blocks[" input_hint_block.4" ] = std::shared_ptr<GGMLBlock>(new Conv2d (16 , 32 , {3 , 3 }, {2 , 2 }, {1 , 1 }));
101+ blocks[" input_hint_block.4" ] = std::shared_ptr<GGMLBlock>(new Conv2d (16 , 32 , {3 , 3 }, {2 , 2 }, {1 , 1 }, { 1 , 1 }, true , direct ));
99102 // nn.SiLU()
100- blocks[" input_hint_block.6" ] = std::shared_ptr<GGMLBlock>(new Conv2d (32 , 32 , {3 , 3 }, {1 , 1 }, {1 , 1 }));
103+ blocks[" input_hint_block.6" ] = std::shared_ptr<GGMLBlock>(new Conv2d (32 , 32 , {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
101104 // nn.SiLU()
102- blocks[" input_hint_block.8" ] = std::shared_ptr<GGMLBlock>(new Conv2d (32 , 96 , {3 , 3 }, {2 , 2 }, {1 , 1 }));
105+ blocks[" input_hint_block.8" ] = std::shared_ptr<GGMLBlock>(new Conv2d (32 , 96 , {3 , 3 }, {2 , 2 }, {1 , 1 }, { 1 , 1 }, true , direct ));
103106 // nn.SiLU()
104- blocks[" input_hint_block.10" ] = std::shared_ptr<GGMLBlock>(new Conv2d (96 , 96 , {3 , 3 }, {1 , 1 }, {1 , 1 }));
107+ blocks[" input_hint_block.10" ] = std::shared_ptr<GGMLBlock>(new Conv2d (96 , 96 , {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
105108 // nn.SiLU()
106- blocks[" input_hint_block.12" ] = std::shared_ptr<GGMLBlock>(new Conv2d (96 , 256 , {3 , 3 }, {2 , 2 }, {1 , 1 }));
109+ blocks[" input_hint_block.12" ] = std::shared_ptr<GGMLBlock>(new Conv2d (96 , 256 , {3 , 3 }, {2 , 2 }, {1 , 1 }, { 1 , 1 }, true , direct ));
107110 // nn.SiLU()
108- blocks[" input_hint_block.14" ] = std::shared_ptr<GGMLBlock>(new Conv2d (256 , model_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
111+ blocks[" input_hint_block.14" ] = std::shared_ptr<GGMLBlock>(new Conv2d (256 , model_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
109112
110113 size_t len_mults = channel_mult.size ();
111114 for (int i = 0 ; i < len_mults; i++) {
@@ -318,8 +321,9 @@ struct ControlNet : public GGMLRunner {
318321
319322 ControlNet (ggml_backend_t backend,
320323 const String2GGMLType& tensor_types = {},
321- SDVersion version = VERSION_SD1)
322- : GGMLRunner(backend), control_net(version) {
324+ SDVersion version = VERSION_SD1,
325+ bool direct = false )
326+ : GGMLRunner(backend), control_net(version, direct) {
323327 control_net.init (params_ctx, tensor_types, " " );
324328 }
325329
@@ -455,4 +459,4 @@ struct ControlNet : public GGMLRunner {
455459 }
456460};
457461
458- #endif // __CONTROL_HPP__
462+ #endif // __CONTROL_HPP__
0 commit comments