Skip to content

Commit 38f5685

Browse files
committed
refactor: upscaler
1 parent 04ca926 commit 38f5685

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

esrgan.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,11 @@ struct ESRGAN : public GGMLRunner {
142142
int scale = 4;
143143
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
144144

145-
ESRGAN(ggml_backend_t backend)
145+
ESRGAN(ggml_backend_t backend,std::map<std::string, enum ggml_type>& tensor_types)
146146
: GGMLRunner(backend) {
147+
rrdb_net.init(params_ctx, tensor_types, "");
147148
}
148-
void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
149-
rrdb_net.init(params_ctx, tensor_types, prefix);
150-
}
149+
151150

152151
std::string get_desc() {
153152
return "esrgan";

upscaler.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ struct UpscalerGGML {
3232
LOG_DEBUG("Using SYCL backend");
3333
backend = ggml_backend_sycl_init(0);
3434
#endif
35-
35+
ModelLoader model_loader;
36+
if (!model_loader.init_from_file(esrgan_path)) {
37+
LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
38+
}
3639
if (!backend) {
3740
LOG_DEBUG("Using CPU backend");
3841
backend = ggml_backend_cpu_init();
3942
}
4043
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
41-
esrgan_upscaler = std::make_shared<ESRGAN>(backend);
44+
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
4245
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
4346
return false;
4447
}

0 commit comments

Comments
 (0)