@@ -83,39 +83,44 @@ class RRDB : public GGMLBlock {
8383
8484class RRDBNet : public GGMLBlock {
8585protected:
86- int scale = 4 ; // default RealESRGAN_x4plus_anime_6B
87- int num_block = 6 ; // default RealESRGAN_x4plus_anime_6B
86+ int scale = 4 ;
87+ int num_block = 23 ;
8888 int num_in_ch = 3 ;
8989 int num_out_ch = 3 ;
90- int num_feat = 64 ; // default RealESRGAN_x4plus_anime_6B
91- int num_grow_ch = 32 ; // default RealESRGAN_x4plus_anime_6B
90+ int num_feat = 64 ;
91+ int num_grow_ch = 32 ;
9292
9393public:
94- RRDBNet () {
94+ RRDBNet (int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
95+ : scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
9596 blocks[" conv_first" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_in_ch, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
9697 for (int i = 0 ; i < num_block; i++) {
9798 std::string name = " body." + std::to_string (i);
9899 blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB (num_feat, num_grow_ch));
99100 }
100101 blocks[" conv_body" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
101- // upsample
102- blocks[" conv_up1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
103- blocks[" conv_up2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
102+ if (scale >= 2 ) {
103+ blocks[" conv_up1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
104+ }
105+ if (scale == 4 ) {
106+ blocks[" conv_up2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
107+ }
104108 blocks[" conv_hr" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_feat, {3 , 3 }, {1 , 1 }, {1 , 1 }));
105109 blocks[" conv_last" ] = std::shared_ptr<GGMLBlock>(new Conv2d (num_feat, num_out_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
106110 }
107111
112+ int get_scale () { return scale; }
113+ int get_num_block () { return num_block; }
114+
108115 struct ggml_tensor * lrelu (struct ggml_context * ctx, struct ggml_tensor * x) {
109116 return ggml_leaky_relu (ctx, x, 0 .2f , true );
110117 }
111118
112119 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
113120 // x: [n, num_in_ch, h, w]
114- // return: [n, num_out_ch, h*4 , w*4 ]
121+ // return: [n, num_out_ch, h*scale , w*scale ]
115122 auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_first" ]);
116123 auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_body" ]);
117- auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_up1" ]);
118- auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_up2" ]);
119124 auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_hr" ]);
120125 auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_last" ]);
121126
@@ -130,28 +135,37 @@ class RRDBNet : public GGMLBlock {
130135 body_feat = conv_body->forward (ctx, body_feat);
131136 feat = ggml_add (ctx, feat, body_feat);
132137 // upsample
133- feat = lrelu (ctx, conv_up1->forward (ctx, ggml_upscale (ctx, feat, 2 , GGML_SCALE_MODE_NEAREST)));
134- feat = lrelu (ctx, conv_up2->forward (ctx, ggml_upscale (ctx, feat, 2 , GGML_SCALE_MODE_NEAREST)));
138+ if (scale >= 2 ) {
139+ auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_up1" ]);
140+ feat = lrelu (ctx, conv_up1->forward (ctx, ggml_upscale (ctx, feat, 2 , GGML_SCALE_MODE_NEAREST)));
141+ if (scale == 4 ) {
142+ auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks[" conv_up2" ]);
143+ feat = lrelu (ctx, conv_up2->forward (ctx, ggml_upscale (ctx, feat, 2 , GGML_SCALE_MODE_NEAREST)));
144+ }
145+ }
146+ // for all scales
135147 auto out = conv_last->forward (ctx, lrelu (ctx, conv_hr->forward (ctx, feat)));
136148 return out;
137149 }
138150};
139151
140152struct ESRGAN : public GGMLRunner {
141- RRDBNet rrdb_net;
153+ std::unique_ptr< RRDBNet> rrdb_net;
142154 int scale = 4 ;
143155 int tile_size = 128 ; // avoid cuda OOM for 4gb VRAM
144156
145157 ESRGAN (ggml_backend_t backend,
146158 bool offload_params_to_cpu,
147159 const String2GGMLType& tensor_types = {})
148160 : GGMLRunner(backend, offload_params_to_cpu) {
149- rrdb_net. init (params_ctx, tensor_types, " " );
161+ // rrdb_net will be created in load_from_file
150162 }
151163
152164 void enable_conv2d_direct () {
165+ if (!rrdb_net)
166+ return ;
153167 std::vector<GGMLBlock*> blocks;
154- rrdb_net. get_all_blocks (blocks);
168+ rrdb_net-> get_all_blocks (blocks);
155169 for (auto block : blocks) {
156170 if (block->get_desc () == " Conv2d" ) {
157171 auto conv_block = (Conv2d*)block;
@@ -167,31 +181,185 @@ struct ESRGAN : public GGMLRunner {
167181 bool load_from_file (const std::string& file_path, int n_threads) {
168182 LOG_INFO (" loading esrgan from '%s'" , file_path.c_str ());
169183
170- alloc_params_buffer ();
171- std::map<std::string, ggml_tensor*> esrgan_tensors;
172- rrdb_net.get_param_tensors (esrgan_tensors);
173-
174184 ModelLoader model_loader;
175185 if (!model_loader.init_from_file (file_path)) {
176186 LOG_ERROR (" init esrgan model loader from file failed: '%s'" , file_path.c_str ());
177187 return false ;
178188 }
179189
180- bool success = model_loader.load_tensors (esrgan_tensors, {}, n_threads);
190+ // Get tensor names
191+ auto tensor_names = model_loader.get_tensor_names ();
192+
193+ // Detect if it's ESRGAN format
194+ bool is_ESRGAN = std::find (tensor_names.begin (), tensor_names.end (), " model.0.weight" ) != tensor_names.end ();
195+
196+ // Detect parameters from tensor names
197+ int detected_num_block = 0 ;
198+ if (is_ESRGAN) {
199+ for (const auto & name : tensor_names) {
200+ if (name.find (" model.1.sub." ) == 0 ) {
201+ size_t first_dot = name.find (' .' , 12 );
202+ if (first_dot != std::string::npos) {
203+ size_t second_dot = name.find (' .' , first_dot + 1 );
204+ if (second_dot != std::string::npos && name.substr (first_dot + 1 , 3 ) == " RDB" ) {
205+ try {
206+ int idx = std::stoi (name.substr (12 , first_dot - 12 ));
207+ detected_num_block = std::max (detected_num_block, idx + 1 );
208+ } catch (...) {
209+ }
210+ }
211+ }
212+ }
213+ }
214+ } else {
215+ // Original format
216+ for (const auto & name : tensor_names) {
217+ if (name.find (" body." ) == 0 ) {
218+ size_t pos = name.find (' .' , 5 );
219+ if (pos != std::string::npos) {
220+ try {
221+ int idx = std::stoi (name.substr (5 , pos - 5 ));
222+ detected_num_block = std::max (detected_num_block, idx + 1 );
223+ } catch (...) {
224+ }
225+ }
226+ }
227+ }
228+ }
229+
230+ int detected_scale = 4 ; // default
231+ if (is_ESRGAN) {
232+ // For ESRGAN format, detect scale by highest model number
233+ int max_model_num = 0 ;
234+ for (const auto & name : tensor_names) {
235+ if (name.find (" model." ) == 0 ) {
236+ size_t dot_pos = name.find (' .' , 6 );
237+ if (dot_pos != std::string::npos) {
238+ try {
239+ int num = std::stoi (name.substr (6 , dot_pos - 6 ));
240+ max_model_num = std::max (max_model_num, num);
241+ } catch (...) {
242+ }
243+ }
244+ }
245+ }
246+ if (max_model_num <= 4 ) {
247+ detected_scale = 1 ;
248+ } else if (max_model_num <= 7 ) {
249+ detected_scale = 2 ;
250+ } else {
251+ detected_scale = 4 ;
252+ }
253+ } else {
254+ // Original format
255+ bool has_conv_up2 = std::any_of (tensor_names.begin (), tensor_names.end (), [](const std::string& name) {
256+ return name == " conv_up2.weight" ;
257+ });
258+ bool has_conv_up1 = std::any_of (tensor_names.begin (), tensor_names.end (), [](const std::string& name) {
259+ return name == " conv_up1.weight" ;
260+ });
261+ if (has_conv_up2) {
262+ detected_scale = 4 ;
263+ } else if (has_conv_up1) {
264+ detected_scale = 2 ;
265+ } else {
266+ detected_scale = 1 ;
267+ }
268+ }
269+
270+ int detected_num_in_ch = 3 ;
271+ int detected_num_out_ch = 3 ;
272+ int detected_num_feat = 64 ;
273+ int detected_num_grow_ch = 32 ;
274+
275+ // Create RRDBNet with detected parameters
276+ rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
277+ rrdb_net->init (params_ctx, {}, " " );
278+
279+ alloc_params_buffer ();
280+ std::map<std::string, ggml_tensor*> esrgan_tensors;
281+ rrdb_net->get_param_tensors (esrgan_tensors);
282+
283+ bool success;
284+ if (is_ESRGAN) {
285+ // Build name mapping for ESRGAN format
286+ std::map<std::string, std::string> expected_to_model;
287+ expected_to_model[" conv_first.weight" ] = " model.0.weight" ;
288+ expected_to_model[" conv_first.bias" ] = " model.0.bias" ;
289+
290+ for (int i = 0 ; i < detected_num_block; i++) {
291+ for (int j = 1 ; j <= 3 ; j++) {
292+ for (int k = 1 ; k <= 5 ; k++) {
293+ std::string expected_weight = " body." + std::to_string (i) + " .rdb" + std::to_string (j) + " .conv" + std::to_string (k) + " .weight" ;
294+ std::string model_weight = " model.1.sub." + std::to_string (i) + " .RDB" + std::to_string (j) + " .conv" + std::to_string (k) + " .0.weight" ;
295+ expected_to_model[expected_weight] = model_weight;
296+
297+ std::string expected_bias = " body." + std::to_string (i) + " .rdb" + std::to_string (j) + " .conv" + std::to_string (k) + " .bias" ;
298+ std::string model_bias = " model.1.sub." + std::to_string (i) + " .RDB" + std::to_string (j) + " .conv" + std::to_string (k) + " .0.bias" ;
299+ expected_to_model[expected_bias] = model_bias;
300+ }
301+ }
302+ }
303+
304+ if (detected_scale == 1 ) {
305+ expected_to_model[" conv_body.weight" ] = " model.1.sub." + std::to_string (detected_num_block) + " .weight" ;
306+ expected_to_model[" conv_body.bias" ] = " model.1.sub." + std::to_string (detected_num_block) + " .bias" ;
307+ expected_to_model[" conv_hr.weight" ] = " model.2.weight" ;
308+ expected_to_model[" conv_hr.bias" ] = " model.2.bias" ;
309+ expected_to_model[" conv_last.weight" ] = " model.4.weight" ;
310+ expected_to_model[" conv_last.bias" ] = " model.4.bias" ;
311+ } else {
312+ expected_to_model[" conv_body.weight" ] = " model.1.sub." + std::to_string (detected_num_block) + " .weight" ;
313+ expected_to_model[" conv_body.bias" ] = " model.1.sub." + std::to_string (detected_num_block) + " .bias" ;
314+ if (detected_scale >= 2 ) {
315+ expected_to_model[" conv_up1.weight" ] = " model.3.weight" ;
316+ expected_to_model[" conv_up1.bias" ] = " model.3.bias" ;
317+ }
318+ if (detected_scale == 4 ) {
319+ expected_to_model[" conv_up2.weight" ] = " model.6.weight" ;
320+ expected_to_model[" conv_up2.bias" ] = " model.6.bias" ;
321+ expected_to_model[" conv_hr.weight" ] = " model.8.weight" ;
322+ expected_to_model[" conv_hr.bias" ] = " model.8.bias" ;
323+ expected_to_model[" conv_last.weight" ] = " model.10.weight" ;
324+ expected_to_model[" conv_last.bias" ] = " model.10.bias" ;
325+ } else if (detected_scale == 2 ) {
326+ expected_to_model[" conv_hr.weight" ] = " model.5.weight" ;
327+ expected_to_model[" conv_hr.bias" ] = " model.5.bias" ;
328+ expected_to_model[" conv_last.weight" ] = " model.7.weight" ;
329+ expected_to_model[" conv_last.bias" ] = " model.7.bias" ;
330+ }
331+ }
332+
333+ std::map<std::string, ggml_tensor*> model_tensors;
334+ for (auto & p : esrgan_tensors) {
335+ auto it = expected_to_model.find (p.first );
336+ if (it != expected_to_model.end ()) {
337+ model_tensors[it->second ] = p.second ;
338+ }
339+ }
340+
341+ success = model_loader.load_tensors (model_tensors, {}, n_threads);
342+ } else {
343+ success = model_loader.load_tensors (esrgan_tensors, {}, n_threads);
344+ }
181345
182346 if (!success) {
183347 LOG_ERROR (" load esrgan tensors from model loader failed" );
184348 return false ;
185349 }
186350
187- LOG_INFO (" esrgan model loaded" );
351+ scale = rrdb_net->get_scale ();
352+ LOG_INFO (" esrgan model loaded with scale=%d, num_block=%d" , scale, detected_num_block);
188353 return success;
189354 }
190355
191356 struct ggml_cgraph * build_graph (struct ggml_tensor * x) {
192- struct ggml_cgraph * gf = ggml_new_graph (compute_ctx);
193- x = to_backend (x);
194- struct ggml_tensor * out = rrdb_net.forward (compute_ctx, x);
357+ if (!rrdb_net)
358+ return nullptr ;
359+ constexpr int kGraphNodes = 1 << 16 ; // 65k
360+ struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, kGraphNodes , /* grads*/ false );
361+ x = to_backend (x);
362+ struct ggml_tensor * out = rrdb_net->forward (compute_ctx, x);
195363 ggml_build_forward_expand (gf, out);
196364 return gf;
197365 }
0 commit comments