Skip to content

Commit 1aab32f

Browse files
committed
fixed safetensors loading for zimage
1 parent 801840d commit 1aab32f

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

otherarch/sdcpp/stable-diffusion.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,24 @@ class StableDiffusionGGML {
264264

265265
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
266266
int tempver = model_loader.get_sd_version();
267+
268+
// kcpp fallback to separate diffusion model passed as model
269+
if (tempver == VERSION_COUNT &&
270+
strlen(SAFE_STR(sd_ctx_params->model_path)) > 0 &&
271+
strlen(SAFE_STR(sd_ctx_params->diffusion_model_path)) == 0 &&
272+
(t5_path_fixed!=""||clipl_path_fixed!=""))
273+
{
274+
bool endswithsafetensors = ends_with(sd_ctx_params->model_path, ".safetensors");
275+
if(endswithsafetensors && !model_loader.has_diffusion_model_tensors())
276+
{
277+
LOG_INFO("SD Diffusion Model tensors missing! Fallback trying alternative tensor names...\n");
278+
if (!model_loader.init_from_file(sd_ctx_params->model_path, "model.diffusion_model.")) {
279+
LOG_WARN("loading diffusion model from '%s' failed", sd_ctx_params->model_path);
280+
}
281+
tempver = model_loader.get_sd_version();
282+
}
283+
}
284+
267285
bool iswan = (tempver==VERSION_WAN2 || tempver==VERSION_WAN2_2_I2V || tempver==VERSION_WAN2_2_TI2V);
268286
bool isqwenimg = (tempver==VERSION_QWEN_IMAGE);
269287
bool iszimg = (tempver==VERSION_Z_IMAGE);
@@ -370,23 +388,6 @@ class StableDiffusionGGML {
370388

371389
version = model_loader.get_sd_version();
372390

373-
// kcpp fallback to separate diffusion model passed as model
374-
if (version == VERSION_COUNT &&
375-
strlen(SAFE_STR(sd_ctx_params->model_path)) > 0 &&
376-
strlen(SAFE_STR(sd_ctx_params->diffusion_model_path)) == 0 &&
377-
t5_path_fixed!="" )
378-
{
379-
bool endswithsafetensors = ends_with(sd_ctx_params->model_path, ".safetensors");
380-
if(endswithsafetensors && !model_loader.has_diffusion_model_tensors())
381-
{
382-
LOG_INFO("SD Diffusion Model tensors missing! Fallback trying alternative tensor names...\n");
383-
if (!model_loader.init_from_file(sd_ctx_params->model_path, "model.diffusion_model.")) {
384-
LOG_WARN("loading diffusion model from '%s' failed", sd_ctx_params->model_path);
385-
}
386-
version = model_loader.get_sd_version();
387-
}
388-
}
389-
390391
if (version == VERSION_COUNT) {
391392
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
392393
return false;

0 commit comments

Comments
 (0)