Skip to content

Commit 9b6339c

Browse files
committed
add conv2d direct for esrgan
1 parent 8974ec1 commit 9b6339c

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed

esrgan.hpp

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@ class ResidualDenseBlock : public GGMLBlock {
1616
protected:
1717
int num_feat;
1818
int num_grow_ch;
19+
bool direct = false;
1920

2021
public:
21-
ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32)
22-
: num_feat(num_feat), num_grow_ch(num_grow_ch) {
23-
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
24-
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
25-
blocks["conv3"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
26-
blocks["conv4"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
27-
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
22+
ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32, bool direct = false)
23+
: num_feat(num_feat), num_grow_ch(num_grow_ch), direct(direct) {
24+
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_grow_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
25+
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
26+
blocks["conv3"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
27+
blocks["conv4"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
28+
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
2829
}
2930

3031
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -58,10 +59,10 @@ class ResidualDenseBlock : public GGMLBlock {
5859

5960
class RRDB : public GGMLBlock {
6061
public:
61-
RRDB(int num_feat, int num_grow_ch = 32) {
62-
blocks["rdb1"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
63-
blocks["rdb2"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
64-
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
62+
RRDB(int num_feat, int num_grow_ch = 32, bool direct = false) {
63+
blocks["rdb1"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch, direct));
64+
blocks["rdb2"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch, direct));
65+
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch, direct));
6566
}
6667

6768
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -89,20 +90,21 @@ class RRDBNet : public GGMLBlock {
8990
int num_out_ch = 3;
9091
int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
9192
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
93+
bool direct = false;
9294

9395
public:
94-
RRDBNet() {
95-
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
96+
RRDBNet(bool direct = false) {
97+
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
9698
for (int i = 0; i < num_block; i++) {
9799
std::string name = "body." + std::to_string(i);
98100
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
99101
}
100-
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
102+
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
101103
// upsample
102-
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
103-
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
104-
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
105-
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
104+
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
105+
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
106+
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
107+
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, direct));
106108
}
107109

108110
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -142,8 +144,8 @@ struct ESRGAN : public GGMLRunner {
142144
int scale = 4;
143145
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
144146

145-
ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
146-
: GGMLRunner(backend) {
147+
ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {}, bool direct = false)
148+
: GGMLRunner(backend), rrdb_net(direct) {
147149
rrdb_net.init(params_ctx, tensor_types, "");
148150
}
149151

@@ -194,4 +196,4 @@ struct ESRGAN : public GGMLRunner {
194196
}
195197
};
196198

197-
#endif // __ESRGAN_HPP__
199+
#endif // __ESRGAN_HPP__

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,8 @@ int main(int argc, const char* argv[]) {
10241024
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
10251025
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
10261026
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
1027-
params.n_threads);
1027+
params.n_threads,
1028+
params.diffusion_conv_direct);
10281029

10291030
if (upscaler_ctx == NULL) {
10301031
printf("new_upscaler_ctx failed\n");

stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
238238
typedef struct upscaler_ctx_t upscaler_ctx_t;
239239

240240
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
241-
int n_threads);
241+
int n_threads,
242+
bool direct);
242243
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
243244

244245
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);

upscaler.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ struct UpscalerGGML {
99
std::shared_ptr<ESRGAN> esrgan_upscaler;
1010
std::string esrgan_path;
1111
int n_threads;
12+
bool direct = false;
1213

13-
UpscalerGGML(int n_threads)
14-
: n_threads(n_threads) {
14+
UpscalerGGML(int n_threads,
15+
bool direct = false)
16+
: n_threads(n_threads),
17+
direct(direct) {
1518
}
1619

1720
bool load_from_file(const std::string& esrgan_path) {
@@ -46,7 +49,7 @@ struct UpscalerGGML {
4649
backend = ggml_backend_cpu_init();
4750
}
4851
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
49-
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
52+
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types, direct);
5053
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
5154
return false;
5255
}
@@ -104,14 +107,15 @@ struct upscaler_ctx_t {
104107
};
105108

106109
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
107-
int n_threads) {
110+
int n_threads,
111+
bool direct = false) {
108112
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
109113
if (upscaler_ctx == NULL) {
110114
return NULL;
111115
}
112116
std::string esrgan_path(esrgan_path_c_str);
113117

114-
upscaler_ctx->upscaler = new UpscalerGGML(n_threads);
118+
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
115119
if (upscaler_ctx->upscaler == NULL) {
116120
return NULL;
117121
}

0 commit comments

Comments
 (0)