Skip to content

Commit ffa05a0

Browse files
Lincoln Steinhipsterusername
authored andcommitted
Only replace vae when it is the broken SDXL 1.0 version
1 parent a20e173 commit ffa05a0

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
2+
"""
3+
This module exports the function has_baked_in_sdxl_vae().
4+
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
5+
which doesn't work properly in fp16 mode.
6+
"""
7+
8+
import hashlib
9+
from pathlib import Path
10+
11+
from safetensors.torch import load_file
12+
13+
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
14+
15+
16+
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
17+
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
18+
hash = _vae_hash(checkpoint_path)
19+
return hash != SDXL_1_0_VAE_HASH
20+
21+
22+
def _vae_hash(checkpoint_path: Path) -> str:
23+
checkpoint = load_file(checkpoint_path, device="cpu")
24+
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
25+
hash = hashlib.new("sha256")
26+
for key in vae_keys:
27+
value = checkpoint[key]
28+
hash.update(bytes(key, "UTF-8"))
29+
hash.update(bytes(str(value), "UTF-8"))
30+
31+
return hash.hexdigest()

invokeai/backend/model_management/models/sdxl.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import json
22
import os
33
from enum import Enum
4+
from pathlib import Path
45
from typing import Literal, Optional
56

67
from omegaconf import OmegaConf
78
from pydantic import Field
89

10+
from invokeai.app.services.config import InvokeAIAppConfig
11+
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
12+
from invokeai.backend.util.logging import InvokeAILogger
13+
914
from .base import (
1015
BaseModelType,
1116
DiffusersModel,
@@ -116,17 +121,20 @@ def convert_if_required(
116121
# The convert script adapted from the diffusers package uses
117122
# strings for the base model type. To avoid making too many
118123
# source code changes, we simply translate here
124+
if Path(output_path).exists():
125+
return output_path
126+
119127
if isinstance(config, cls.CheckpointConfig):
120128
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
121129

122130
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
123-
# then we bake it into the converted model.
124-
from invokeai.app.services.config import InvokeAIAppConfig
125-
126-
kwargs = dict()
131+
# then we bake it into the converted model unless there is already
132+
# a nonstandard VAE installed.
133+
kwargs = {}
127134
app_config = InvokeAIAppConfig.get_config()
128135
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
129-
if vae_path.exists():
136+
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
137+
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
130138
kwargs["vae_path"] = vae_path
131139

132140
return _convert_ckpt_and_cache(

0 commit comments

Comments
 (0)