Skip to content

Commit 82ffb58

Browse files
feat(mm): port vae to new API
1 parent 1db1264 commit 82ffb58

File tree

3 files changed

+134
-52
lines changed

3 files changed

+134
-52
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# pyright: reportIncompatibleVariableOverride=false
2424
import json
2525
import logging
26+
import re
2627
import time
2728
from abc import ABC, abstractmethod
2829
from enum import Enum
@@ -73,6 +74,15 @@ class InvalidModelConfigException(Exception):
7374
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
7475

7576

77+
def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
    """Return the model class name recorded in a parsed config dictionary.

    The "_class_name" entry takes precedence when the key is present. Otherwise
    the first entry of "architectures" is used.

    Args:
        config: A parsed configuration dictionary.

    Returns:
        The class name, or None when neither key yields one. This includes an
        "architectures" entry that is empty or None, which previously raised
        IndexError / TypeError instead of reporting "unknown".
    """
    if "_class_name" in config:
        return config["_class_name"]

    architectures = config.get("architectures")
    if architectures:
        return architectures[0]

    return None
84+
85+
7686
class SubmodelDefinition(BaseModel):
7787
path_or_prefix: str
7888
model_type: ModelType
@@ -578,18 +588,122 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
578588
}
579589

580590

581-
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
591+
class VAEConfigBase(CheckpointConfigBase):
    # Common base for the standalone VAE configs: pins the `type` field to ModelType.VAE.
    # NOTE(review): VAEDiffusersConfig also derives from this class, and therefore from
    # CheckpointConfigBase - confirm that is intended for diffusers-format VAEs.
    type: Literal[ModelType.VAE] = ModelType.VAE
593+
594+
595+
class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    # State-dict key prefixes that `matches` looks for in a single-file model.
    KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Rate how well `mod` fits this config class.

        Returns OVERRIDE when the caller's overrides name both the VAE type and
        the checkpoint format, MAYBE when `mod` is not a directory and has a
        state-dict key starting with one of KEY_PREFIXES, and NEVER otherwise.
        """
        forced = overrides.get("type") is ModelType.VAE and overrides.get("format") is ModelFormat.Checkpoint
        if forced:
            return MatchCertainty.OVERRIDE

        # The directory test runs first so that keys are never probed for a folder.
        if mod.path.is_dir() or not mod.has_keys_starting_with(cls.KEY_PREFIXES):
            return MatchCertainty.NEVER

        return MatchCertainty.MAYBE

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the fields derived from `mod`; currently only "base"."""
        return {"base": cls.get_base_type(mod)}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Guess the base model type from the file name of `mod`.

        Heuristic: VAEs of all architectures have a similar structure, so the
        best we can do is guess based on the name. Patterns are tried in order
        and matched case-insensitively; the first hit wins.

        Raises:
            InvalidModelConfigException: if no pattern matches the file name.
        """
        filename = mod.path.name
        name_hints = (
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        )
        guess = next(
            (base for pattern, base in name_hints if re.search(pattern, filename, re.IGNORECASE)),
            None,
        )
        if guess is None:
            raise InvalidModelConfigException("Cannot determine base type")
        return guess
585636

586637

587-
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
638+
class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    # Class names (as read from config.json) that identify a diffusers VAE folder.
    CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Rate how well `mod` fits this config class.

        Returns OVERRIDE when the caller's overrides name both the VAE type and
        the diffusers format, EXACT when `mod` is not a single file and its
        config.json names one of CLASS_NAMES, and NEVER otherwise.
        """
        is_vae_override = overrides.get("type") is ModelType.VAE
        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers

        if is_vae_override and is_diffusers_override:
            return MatchCertainty.OVERRIDE

        if mod.path.is_file():
            return MatchCertainty.NEVER

        try:
            config = cls.get_config(mod)
            class_name = get_class_name_from_config(config)
            if class_name in cls.CLASS_NAMES:
                return MatchCertainty.EXACT
        except Exception:
            # Deliberately best-effort: a missing, unreadable or malformed config.json
            # means "this is not a diffusers VAE", not a probing failure.
            pass

        return MatchCertainty.NEVER

    @classmethod
    def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Load and return the parsed `config.json` found directly under `mod.path`.

        Errors from opening or parsing the file are propagated to the caller.
        """
        config_path = mod.path / "config.json"
        # JSON is UTF-8 by specification; an explicit encoding keeps the result
        # independent of the platform's locale default.
        with open(config_path, "r", encoding="utf-8") as file:
            return json.load(file)

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the fields derived from `mod`; currently only "base"."""
        base = cls.get_base_type(mod)
        return {"base": base}

    @classmethod
    def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
        """Return SDXL when the config or the name suggests it, SD 1.x otherwise."""
        if cls._config_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        elif cls._name_looks_like_sdxl(mod):
            return BaseModelType.StableDiffusionXL
        else:
            # We do not support diffusers VAEs for any other base model at this time,
            # so everything else is treated as SD 1.x.
            return BaseModelType.StableDiffusion1

    @classmethod
    def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        """Check config.json for the values used to recognise an SDXL VAE."""
        config = cls.get_config(mod)
        # Heuristic: these config values distinguish Stability's SD 1.x VAE from their SDXL VAE.
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    @classmethod
    def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
        """Check whether the guessed model name contains "xl" at a word boundary (case-insensitive)."""
        # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
        # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
        # we can do is guess based on name.
        return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))

    @classmethod
    def _guess_name(cls, mod: ModelOnDisk) -> str:
        """Return the folder name of `mod`, or the parent folder's name when the folder is called "vae"."""
        name = mod.path.name
        if name == "vae":
            name = mod.path.parent.name
        return name
706+
593707

594708
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
595709
"""Model config for ControlNet models (diffusers version)."""

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import re
32
from pathlib import Path
43
from typing import Any, Callable, Dict, Literal, Optional, Union
54

@@ -654,21 +653,6 @@ def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
654653
return SchedulerPredictionType.Epsilon
655654

656655

657-
class VaeCheckpointProbe(CheckpointProbeBase):
658-
def get_base_type(self) -> BaseModelType:
659-
# VAEs of all base types have the same structure, so we wimp out and
660-
# guess using the name.
661-
for regexp, basetype in [
662-
(r"xl", BaseModelType.StableDiffusionXL),
663-
(r"sd2", BaseModelType.StableDiffusion2),
664-
(r"vae", BaseModelType.StableDiffusion1),
665-
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
666-
]:
667-
if re.search(regexp, self.model_path.name, re.IGNORECASE):
668-
return basetype
669-
raise InvalidModelConfigException("Cannot determine base type")
670-
671-
672656
class LoRACheckpointProbe(CheckpointProbeBase):
673657
"""Class for LoRA checkpoints."""
674658

@@ -895,36 +879,6 @@ def get_variant_type(self) -> ModelVariantType:
895879
return ModelVariantType.Normal
896880

897881

898-
class VaeFolderProbe(FolderProbeBase):
899-
def get_base_type(self) -> BaseModelType:
900-
if self._config_looks_like_sdxl():
901-
return BaseModelType.StableDiffusionXL
902-
elif self._name_looks_like_sdxl():
903-
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
904-
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
905-
return BaseModelType.StableDiffusionXL
906-
else:
907-
return BaseModelType.StableDiffusion1
908-
909-
def _config_looks_like_sdxl(self) -> bool:
910-
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
911-
config_file = self.model_path / "config.json"
912-
if not config_file.exists():
913-
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
914-
with open(config_file, "r") as file:
915-
config = json.load(file)
916-
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
917-
918-
def _name_looks_like_sdxl(self) -> bool:
919-
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
920-
921-
def _guess_name(self) -> str:
922-
name = self.model_path.name
923-
if name == "vae":
924-
name = self.model_path.parent.name
925-
return name
926-
927-
928882
class T5EncoderFolderProbe(FolderProbeBase):
929883
def get_base_type(self) -> BaseModelType:
930884
return BaseModelType.Any
@@ -1080,7 +1034,6 @@ def get_base_type(self) -> BaseModelType:
10801034

10811035
# Register probe classes
10821036
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
1083-
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
10841037
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
10851038
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
10861039
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
@@ -1093,7 +1046,6 @@ def get_base_type(self) -> BaseModelType:
10931046
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
10941047

10951048
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
1096-
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
10971049
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
10981050
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
10991051
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)

invokeai/backend/model_manager/model_on_disk.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,19 @@ def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
128128
f"Please specify the intended file using the 'path' argument"
129129
)
130130
return path
131+
132+
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
    """Return True when every name in `keys` is a string key of the loaded state dict.

    Args:
        keys: Key names that must all be present. An empty set yields True.
        path: Forwarded unchanged to `self.load_state_dict`.
    """
    state_dict = self.load_state_dict(path)
    present = state_dict.keys()
    # Test each requested name directly instead of first copying every key of a
    # potentially very large state dict into a throwaway set.
    return all(isinstance(key, str) and key in present for key in keys)
135+
136+
def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True when at least one string key of the loaded state dict starts with any of `prefixes`.

    Non-string keys are ignored. An empty `prefixes` yields False.

    Args:
        prefixes: Candidate key prefixes.
        path: Forwarded unchanged to `self.load_state_dict`.
    """
    state_dict = self.load_state_dict(path)
    # str.startswith accepts a tuple of alternatives; build it once, outside the scan.
    candidates = tuple(prefixes)
    return any(isinstance(key, str) and key.startswith(candidates) for key in state_dict.keys())
141+
142+
def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
    """Return True when at least one string key of the loaded state dict ends with any of the given suffixes.

    Non-string keys are ignored. An empty collection yields False.

    Args:
        prefixes: The candidate key *suffixes*. The parameter name is kept
            as-is so that existing keyword callers continue to work.
        path: Forwarded unchanged to `self.load_state_dict`.
    """
    state_dict = self.load_state_dict(path)
    # str.endswith accepts a tuple of alternatives; build it once, outside the scan.
    suffixes = tuple(prefixes)
    return any(isinstance(key, str) and key.endswith(suffixes) for key in state_dict.keys())

0 commit comments

Comments
 (0)