Skip to content

Commit fef7391

Browse files
wbruna and LostRuins authored
sd: clean up changes against stable-diffusion.cpp 90ef5f8 (LostRuins#1804)
* sd: clean up changes against stable-diffusion.cpp 90ef5f8

  Clean up the diff, and include a few missing changes, mainly from the upscaler and model weight type statistics.

* added line clear again

* remove excess spaces

---------

Co-authored-by: LostRuins Concedo <[email protected]>
1 parent 0aaa8ca commit fef7391

18 files changed, with 286 additions and 112 deletions.

otherarch/sdcpp/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -563,4 +563,4 @@ class VideoResBlock : public ResBlock {
563563
}
564564
};
565565

566-
#endif // __COMMON_HPP__
566+
#endif // __COMMON_HPP__

otherarch/sdcpp/conditioner.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1574,4 +1574,4 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
15741574
}
15751575
};
15761576

1577-
#endif
1577+
#endif

otherarch/sdcpp/diffusion_model.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -320,4 +320,4 @@ struct QwenImageModel : public DiffusionModel {
320320
}
321321
};
322322

323-
#endif
323+
#endif

otherarch/sdcpp/esrgan.hpp

Lines changed: 193 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -83,39 +83,44 @@ class RRDB : public GGMLBlock {
8383

8484
class RRDBNet : public GGMLBlock {
8585
protected:
86-
int scale = 4; // default RealESRGAN_x4plus_anime_6B
87-
int num_block = 6; // default RealESRGAN_x4plus_anime_6B
86+
int scale = 4;
87+
int num_block = 23;
8888
int num_in_ch = 3;
8989
int num_out_ch = 3;
90-
int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
91-
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
90+
int num_feat = 64;
91+
int num_grow_ch = 32;
9292

9393
public:
94-
RRDBNet() {
94+
RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
95+
: scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
9596
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
9697
for (int i = 0; i < num_block; i++) {
9798
std::string name = "body." + std::to_string(i);
9899
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
99100
}
100101
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
101-
// upsample
102-
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
103-
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
102+
if (scale >= 2) {
103+
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
104+
}
105+
if (scale == 4) {
106+
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
107+
}
104108
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
105109
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
106110
}
107111

112+
int get_scale() { return scale; }
113+
int get_num_block() { return num_block; }
114+
108115
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
109116
return ggml_leaky_relu(ctx, x, 0.2f, true);
110117
}
111118

112119
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
113120
// x: [n, num_in_ch, h, w]
114-
// return: [n, num_out_ch, h*4, w*4]
121+
// return: [n, num_out_ch, h*scale, w*scale]
115122
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
116123
auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
117-
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
118-
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
119124
auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
120125
auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);
121126

@@ -130,28 +135,37 @@ class RRDBNet : public GGMLBlock {
130135
body_feat = conv_body->forward(ctx, body_feat);
131136
feat = ggml_add(ctx, feat, body_feat);
132137
// upsample
133-
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
134-
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
138+
if (scale >= 2) {
139+
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
140+
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
141+
if (scale == 4) {
142+
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
143+
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
144+
}
145+
}
146+
// for all scales
135147
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
136148
return out;
137149
}
138150
};
139151

140152
struct ESRGAN : public GGMLRunner {
141-
RRDBNet rrdb_net;
153+
std::unique_ptr<RRDBNet> rrdb_net;
142154
int scale = 4;
143155
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
144156

145157
ESRGAN(ggml_backend_t backend,
146158
bool offload_params_to_cpu,
147159
const String2GGMLType& tensor_types = {})
148160
: GGMLRunner(backend, offload_params_to_cpu) {
149-
rrdb_net.init(params_ctx, tensor_types, "");
161+
// rrdb_net will be created in load_from_file
150162
}
151163

152164
void enable_conv2d_direct() {
165+
if (!rrdb_net)
166+
return;
153167
std::vector<GGMLBlock*> blocks;
154-
rrdb_net.get_all_blocks(blocks);
168+
rrdb_net->get_all_blocks(blocks);
155169
for (auto block : blocks) {
156170
if (block->get_desc() == "Conv2d") {
157171
auto conv_block = (Conv2d*)block;
@@ -167,31 +181,185 @@ struct ESRGAN : public GGMLRunner {
167181
bool load_from_file(const std::string& file_path, int n_threads) {
168182
LOG_INFO("loading esrgan from '%s'", file_path.c_str());
169183

170-
alloc_params_buffer();
171-
std::map<std::string, ggml_tensor*> esrgan_tensors;
172-
rrdb_net.get_param_tensors(esrgan_tensors);
173-
174184
ModelLoader model_loader;
175185
if (!model_loader.init_from_file(file_path)) {
176186
LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
177187
return false;
178188
}
179189

180-
bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
190+
// Get tensor names
191+
auto tensor_names = model_loader.get_tensor_names();
192+
193+
// Detect if it's ESRGAN format
194+
bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();
195+
196+
// Detect parameters from tensor names
197+
int detected_num_block = 0;
198+
if (is_ESRGAN) {
199+
for (const auto& name : tensor_names) {
200+
if (name.find("model.1.sub.") == 0) {
201+
size_t first_dot = name.find('.', 12);
202+
if (first_dot != std::string::npos) {
203+
size_t second_dot = name.find('.', first_dot + 1);
204+
if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
205+
try {
206+
int idx = std::stoi(name.substr(12, first_dot - 12));
207+
detected_num_block = std::max(detected_num_block, idx + 1);
208+
} catch (...) {
209+
}
210+
}
211+
}
212+
}
213+
}
214+
} else {
215+
// Original format
216+
for (const auto& name : tensor_names) {
217+
if (name.find("body.") == 0) {
218+
size_t pos = name.find('.', 5);
219+
if (pos != std::string::npos) {
220+
try {
221+
int idx = std::stoi(name.substr(5, pos - 5));
222+
detected_num_block = std::max(detected_num_block, idx + 1);
223+
} catch (...) {
224+
}
225+
}
226+
}
227+
}
228+
}
229+
230+
int detected_scale = 4; // default
231+
if (is_ESRGAN) {
232+
// For ESRGAN format, detect scale by highest model number
233+
int max_model_num = 0;
234+
for (const auto& name : tensor_names) {
235+
if (name.find("model.") == 0) {
236+
size_t dot_pos = name.find('.', 6);
237+
if (dot_pos != std::string::npos) {
238+
try {
239+
int num = std::stoi(name.substr(6, dot_pos - 6));
240+
max_model_num = std::max(max_model_num, num);
241+
} catch (...) {
242+
}
243+
}
244+
}
245+
}
246+
if (max_model_num <= 4) {
247+
detected_scale = 1;
248+
} else if (max_model_num <= 7) {
249+
detected_scale = 2;
250+
} else {
251+
detected_scale = 4;
252+
}
253+
} else {
254+
// Original format
255+
bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
256+
return name == "conv_up2.weight";
257+
});
258+
bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
259+
return name == "conv_up1.weight";
260+
});
261+
if (has_conv_up2) {
262+
detected_scale = 4;
263+
} else if (has_conv_up1) {
264+
detected_scale = 2;
265+
} else {
266+
detected_scale = 1;
267+
}
268+
}
269+
270+
int detected_num_in_ch = 3;
271+
int detected_num_out_ch = 3;
272+
int detected_num_feat = 64;
273+
int detected_num_grow_ch = 32;
274+
275+
// Create RRDBNet with detected parameters
276+
rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
277+
rrdb_net->init(params_ctx, {}, "");
278+
279+
alloc_params_buffer();
280+
std::map<std::string, ggml_tensor*> esrgan_tensors;
281+
rrdb_net->get_param_tensors(esrgan_tensors);
282+
283+
bool success;
284+
if (is_ESRGAN) {
285+
// Build name mapping for ESRGAN format
286+
std::map<std::string, std::string> expected_to_model;
287+
expected_to_model["conv_first.weight"] = "model.0.weight";
288+
expected_to_model["conv_first.bias"] = "model.0.bias";
289+
290+
for (int i = 0; i < detected_num_block; i++) {
291+
for (int j = 1; j <= 3; j++) {
292+
for (int k = 1; k <= 5; k++) {
293+
std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
294+
std::string model_weight = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
295+
expected_to_model[expected_weight] = model_weight;
296+
297+
std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
298+
std::string model_bias = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
299+
expected_to_model[expected_bias] = model_bias;
300+
}
301+
}
302+
}
303+
304+
if (detected_scale == 1) {
305+
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
306+
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
307+
expected_to_model["conv_hr.weight"] = "model.2.weight";
308+
expected_to_model["conv_hr.bias"] = "model.2.bias";
309+
expected_to_model["conv_last.weight"] = "model.4.weight";
310+
expected_to_model["conv_last.bias"] = "model.4.bias";
311+
} else {
312+
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
313+
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
314+
if (detected_scale >= 2) {
315+
expected_to_model["conv_up1.weight"] = "model.3.weight";
316+
expected_to_model["conv_up1.bias"] = "model.3.bias";
317+
}
318+
if (detected_scale == 4) {
319+
expected_to_model["conv_up2.weight"] = "model.6.weight";
320+
expected_to_model["conv_up2.bias"] = "model.6.bias";
321+
expected_to_model["conv_hr.weight"] = "model.8.weight";
322+
expected_to_model["conv_hr.bias"] = "model.8.bias";
323+
expected_to_model["conv_last.weight"] = "model.10.weight";
324+
expected_to_model["conv_last.bias"] = "model.10.bias";
325+
} else if (detected_scale == 2) {
326+
expected_to_model["conv_hr.weight"] = "model.5.weight";
327+
expected_to_model["conv_hr.bias"] = "model.5.bias";
328+
expected_to_model["conv_last.weight"] = "model.7.weight";
329+
expected_to_model["conv_last.bias"] = "model.7.bias";
330+
}
331+
}
332+
333+
std::map<std::string, ggml_tensor*> model_tensors;
334+
for (auto& p : esrgan_tensors) {
335+
auto it = expected_to_model.find(p.first);
336+
if (it != expected_to_model.end()) {
337+
model_tensors[it->second] = p.second;
338+
}
339+
}
340+
341+
success = model_loader.load_tensors(model_tensors, {}, n_threads);
342+
} else {
343+
success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
344+
}
181345

182346
if (!success) {
183347
LOG_ERROR("load esrgan tensors from model loader failed");
184348
return false;
185349
}
186350

187-
LOG_INFO("esrgan model loaded");
351+
scale = rrdb_net->get_scale();
352+
LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
188353
return success;
189354
}
190355

191356
struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
192-
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
193-
x = to_backend(x);
194-
struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
357+
if (!rrdb_net)
358+
return nullptr;
359+
constexpr int kGraphNodes = 1 << 16; // 65k
360+
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
361+
x = to_backend(x);
362+
struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
195363
ggml_build_forward_expand(gf, out);
196364
return gf;
197365
}

otherarch/sdcpp/flux.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1107,4 +1107,4 @@ namespace Flux {
11071107

11081108
} // namespace Flux
11091109

1110-
#endif // __FLUX_HPP__
1110+
#endif // __FLUX_HPP__

otherarch/sdcpp/ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2382,4 +2382,4 @@ class MultiheadAttention : public GGMLBlock {
23822382
}
23832383
};
23842384

2385-
#endif // __GGML_EXTEND__HPP__
2385+
#endif // __GGML_EXTEND__HPP__

otherarch/sdcpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1586,4 +1586,4 @@ int main(int argc, const char* argv[]) {
15861586
release_all_resources();
15871587

15881588
return 0;
1589-
}
1589+
}

0 commit comments

Comments (0)