Skip to content

Commit 170663f

Browse files
committed
Refactor: fix runtime type override
1 parent 38f5685 commit 170663f

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

model.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,21 @@ ggml_type ModelLoader::get_vae_wtype() {
16071607
return GGML_TYPE_COUNT;
16081608
}
16091609

1610+
/**
 * Override the recorded weight type for tensors selected by name prefix.
 *
 * Every entry of tensor_storages_types whose key begins with `prefix` is
 * considered; an empty `prefix` selects all entries. For a selected entry,
 * the first element of tensor_storages carrying the same name is located and
 * the entry's type is replaced with `wtype` only if
 * tensor_should_be_converted() returns true for that storage and `wtype`.
 * Entries without a same-named tensor storage are left untouched.
 *
 * @param wtype  Type to assign to the selected entries.
 * @param prefix Tensor-name prefix used as the selection filter.
 */
void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
    const bool select_all = prefix.empty();

    for (auto& entry : tensor_storages_types) {
        const auto& tensor_name = entry.first;

        // Skip names that do not start with the requested prefix.
        if (!select_all && tensor_name.compare(0, prefix.size(), prefix) != 0) {
            continue;
        }

        for (auto& storage : tensor_storages) {
            if (storage.name == tensor_name) {
                if (tensor_should_be_converted(storage, wtype)) {
                    entry.second = wtype;
                }
                // Only the first storage with this name is consulted.
                break;
            }
        }
    }
}
1624+
16101625
std::string ModelLoader::load_merges() {
16111626
std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
16121627
return merges_utf8_str;

model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class ModelLoader {
186186
ggml_type get_conditioner_wtype();
187187
ggml_type get_diffusion_model_wtype();
188188
ggml_type get_vae_wtype();
189+
// Sets the stored type to `wtype` for entries of tensor_storages_types whose
// name starts with `prefix` (all entries when `prefix` is empty), but only
// where tensor_should_be_converted() accepts the matching tensor storage.
void set_wtype_override(ggml_type wtype, std::string prefix = "");
189190
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend);
190191
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
191192
ggml_backend_t backend,

stable-diffusion.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,12 @@ class StableDiffusionGGML {
264264
conditioner_wtype = wtype;
265265
diffusion_model_wtype = wtype;
266266
vae_wtype = wtype;
267+
model_loader.set_wtype_override(wtype);
267268
}
268269

269270
if (version == VERSION_SDXL) {
270271
vae_wtype = GGML_TYPE_F32;
272+
model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
271273
}
272274

273275
LOG_INFO("Weight type: %s", model_wtype != SD_TYPE_COUNT ? ggml_type_name(model_wtype) : "??");

0 commit comments

Comments
 (0)