Skip to content

Commit cdcdecc

Browse files
feat(mm): port spandrel to new API
1 parent 8b6fe5c commit cdcdecc

File tree

2 files changed

+51
-42
lines changed

2 files changed

+51
-42
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pathlib import Path
3131
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
3232

33+
import spandrel
3334
import torch
3435
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
3536
from typing_extensions import Annotated, Any, Dict
@@ -56,6 +57,7 @@
5657
variant_type_adapter,
5758
)
5859
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
60+
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
5961
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
6062

6163
logger = logging.getLogger(__name__)
@@ -605,30 +607,36 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
605607
class TextualInversionConfigBase(ABC, BaseModel):
606608
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
607609

610+
KNOWN_SUFFIXES: ClassVar = {"bin", "safetensors", "pt", "ckpt"}
611+
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
612+
608613
@classmethod
609614
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
610-
p = path or mod.path
615+
try:
616+
p = path or mod.path
611617

612-
if not p.exists():
613-
return False
618+
if not p.exists():
619+
return False
614620

615-
if p.is_dir():
616-
return False
621+
if p.is_dir():
622+
return False
617623

618-
if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
619-
return True
624+
if p.name in [f"learned_embeds.{s}" for s in cls.KNOWN_SUFFIXES]:
625+
return True
620626

621-
state_dict = mod.load_state_dict(p)
627+
state_dict = mod.load_state_dict(p)
622628

623-
# Heuristic: textual inversion embeddings have these keys
624-
if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()):
625-
return True
629+
# Heuristic: textual inversion embeddings have these keys
630+
if any(key in cls.KNOWN_KEYS for key in state_dict.keys()):
631+
return True
626632

627-
# Heuristic: small state dict with all tensor values
628-
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
629-
return True
633+
# Heuristic: small state dict with all tensor values
634+
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
635+
return True
630636

631-
return False
637+
return False
638+
except Exception:
639+
return False
632640

633641
@classmethod
634642
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
@@ -716,8 +724,8 @@ def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
716724
if mod.path.is_file():
717725
return MatchCertainty.NEVER
718726

719-
for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}:
720-
if cls.file_looks_like_embedding(mod, mod.path / filename):
727+
for p in mod.path.iterdir():
728+
if cls.file_looks_like_embedding(mod, p):
721729
return MatchCertainty.MAYBE
722730

723731
return MatchCertainty.NEVER
@@ -929,14 +937,39 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
929937
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
930938

931939

932-
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
940+
class SpandrelImageToImageConfig(ModelConfigBase):
933941
"""Model config for Spandrel Image to Image models."""
934942

935943
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
936944

937945
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
938946
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
939947

948+
@classmethod
949+
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
950+
if not mod.path.is_file():
951+
return MatchCertainty.NEVER
952+
953+
try:
954+
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
955+
# explored to avoid this:
956+
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
957+
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
958+
# supported on meta tensors.
959+
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
960+
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
961+
# maintain it, and the risk of false positive detections is higher.
962+
SpandrelImageToImageModel.load_from_file(mod.path)
963+
return MatchCertainty.EXACT
964+
except spandrel.UnsupportedModelError:
965+
pass
966+
except Exception as e:
967+
logger.warning(
968+
f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}"
969+
)
970+
971+
return MatchCertainty.NEVER
972+
940973

941974
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
942975
"""Model config for SigLIP."""

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import picklescan.scanner as pscan
77
import safetensors.torch
8-
import spandrel
98
import torch
109

1110
import invokeai.backend.util.logging as logger
@@ -59,7 +58,6 @@
5958
)
6059
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
6160
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
62-
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
6361
from invokeai.backend.util.silence_warnings import SilenceWarnings
6462

6563
CkptType = Dict[str | int, Any]
@@ -340,26 +338,6 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
340338
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
341339
return ModelType.TextualInversion
342340

343-
# Check if the model can be loaded as a SpandrelImageToImageModel.
344-
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
345-
try:
346-
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
347-
# explored to avoid this:
348-
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
349-
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
350-
# supported on meta tensors.
351-
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
352-
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
353-
# maintain it, and the risk of false positive detections is higher.
354-
SpandrelImageToImageModel.load_from_file(model_path)
355-
return ModelType.SpandrelImageToImage
356-
except spandrel.UnsupportedModelError:
357-
pass
358-
except Exception as e:
359-
logger.warning(
360-
f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}"
361-
)
362-
363341
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
364342

365343
@classmethod
@@ -1110,7 +1088,6 @@ def get_base_type(self) -> BaseModelType:
11101088
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
11111089
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
11121090
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
1113-
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
11141091
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
11151092
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
11161093
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
@@ -1123,7 +1100,6 @@ def get_base_type(self) -> BaseModelType:
11231100
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
11241101
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
11251102
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
1126-
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
11271103
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
11281104
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
11291105
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)

0 commit comments

Comments
 (0)