@@ -16,15 +16,16 @@ class ResidualDenseBlock : public GGMLBlock {
1616protected:
1717 int num_feat;
1818 int num_grow_ch;
19+ bool direct = false ;
1920
2021public:
21- ResidualDenseBlock (int num_feat = 64 , int num_grow_ch = 32 )
22- : num_feat(num_feat), num_grow_ch(num_grow_ch) {
23- blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
24- blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
25- blocks[" conv3" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 2 * num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
26- blocks[" conv4" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 3 * num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
27- blocks[" conv5" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 4 * num_grow_ch, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
22+ ResidualDenseBlock (int num_feat = 64 , int num_grow_ch = 32 , bool direct = false )
23+ : num_feat(num_feat), num_grow_ch(num_grow_ch), direct(direct) {
24+ blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
25+ blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
26+ blocks[" conv3" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 2 * num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
27+ blocks[" conv4" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 3 * num_grow_ch, num_grow_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
28+ blocks[" conv5" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat + 4 * num_grow_ch, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
2829 }
2930
3031 struct ggml_tensor * lrelu (struct ggml_context * ctx, struct ggml_tensor * x) {
@@ -58,10 +59,10 @@ class ResidualDenseBlock : public GGMLBlock {
5859
5960class RRDB : public GGMLBlock {
6061public:
61- RRDB (int num_feat, int num_grow_ch = 32 ) {
62- blocks[" rdb1" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch));
63- blocks[" rdb2" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch));
64- blocks[" rdb3" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch));
62+ RRDB (int num_feat, int num_grow_ch = 32 , bool direct = false ) {
63+ blocks[" rdb1" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch, direct ));
64+ blocks[" rdb2" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch, direct ));
65+ blocks[" rdb3" ] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock (num_feat, num_grow_ch, direct ));
6566 }
6667
6768 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
@@ -89,20 +90,21 @@ class RRDBNet : public GGMLBlock {
8990 int num_out_ch = 3 ;
9091 int num_feat = 64 ; // default RealESRGAN_x4plus_anime_6B
9192 int num_grow_ch = 32 ; // default RealESRGAN_x4plus_anime_6B
93+ bool direct = false ;
9294
9395public:
94- RRDBNet () {
95- blocks[" conv_first" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_in_ch, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
96+ RRDBNet (bool direct = false ) {
97+ blocks[" conv_first" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_in_ch, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
9698 for (int i = 0 ; i < num_block; i++) {
9799 std::string name = " body." + std::to_string (i);
98100 blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB (num_feat, num_grow_ch));
99101 }
100- blocks[" conv_body" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
102+ blocks[" conv_body" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
101103 // upsample
102- blocks[" conv_up1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
103- blocks[" conv_up2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
104- blocks[" conv_hr" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
105- blocks[" conv_last" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_out_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
104+ blocks[" conv_up1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
105+ blocks[" conv_up2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
106+ blocks[" conv_hr" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
107+ blocks[" conv_last" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_out_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , direct ));
106108 }
107109
108110 struct ggml_tensor * lrelu (struct ggml_context * ctx, struct ggml_tensor * x) {
@@ -142,8 +144,8 @@ struct ESRGAN : public GGMLRunner {
142144 int scale = 4 ;
143145 int tile_size = 128 ; // avoid cuda OOM for 4gb VRAM
144146
145- ESRGAN (ggml_backend_t backend, const String2GGMLType& tensor_types = {})
146- : GGMLRunner(backend) {
147+ ESRGAN (ggml_backend_t backend, const String2GGMLType& tensor_types = {}, bool direct = false )
148+ : GGMLRunner(backend), rrdb_net(direct) {
147149 rrdb_net.init (params_ctx, tensor_types, " " );
148150 }
149151
@@ -194,4 +196,4 @@ struct ESRGAN : public GGMLRunner {
194196 }
195197};
196198
197- #endif // __ESRGAN_HPP__
199+ #endif // __ESRGAN_HPP__
0 commit comments