|
30 | 30 | from pathlib import Path
|
31 | 31 | from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
|
32 | 32 |
|
| 33 | +import spandrel |
33 | 34 | import torch
|
34 | 35 | from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
35 | 36 | from typing_extensions import Annotated, Any, Dict
|
|
56 | 57 | variant_type_adapter,
|
57 | 58 | )
|
58 | 59 | from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
| 60 | +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel |
59 | 61 | from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
60 | 62 |
|
61 | 63 | logger = logging.getLogger(__name__)
|
@@ -605,30 +607,36 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
605 | 607 | class TextualInversionConfigBase(ABC, BaseModel):
|
606 | 608 | type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
607 | 609 |
|
| 610 | + KNOWN_SUFFIXES: ClassVar = {"bin", "safetensors", "pt", "ckpt"} |
| 611 | + KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"} |
| 612 | + |
608 | 613 | @classmethod
|
609 | 614 | def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
610 |
| - p = path or mod.path |
| 615 | + try: |
| 616 | + p = path or mod.path |
611 | 617 |
|
612 |
| - if not p.exists(): |
613 |
| - return False |
| 618 | + if not p.exists(): |
| 619 | + return False |
614 | 620 |
|
615 |
| - if p.is_dir(): |
616 |
| - return False |
| 621 | + if p.is_dir(): |
| 622 | + return False |
617 | 623 |
|
618 |
| - if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}: |
619 |
| - return True |
| 624 | + if p.name in [f"learned_embeds.{s}" for s in cls.KNOWN_SUFFIXES]: |
| 625 | + return True |
620 | 626 |
|
621 |
| - state_dict = mod.load_state_dict(p) |
| 627 | + state_dict = mod.load_state_dict(p) |
622 | 628 |
|
623 |
| - # Heuristic: textual inversion embeddings have these keys |
624 |
| - if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()): |
625 |
| - return True |
| 629 | + # Heuristic: textual inversion embeddings have these keys |
| 630 | + if any(key in cls.KNOWN_KEYS for key in state_dict.keys()): |
| 631 | + return True |
626 | 632 |
|
627 |
| - # Heuristic: small state dict with all tensor values |
628 |
| - if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): |
629 |
| - return True |
| 633 | + # Heuristic: small state dict with all tensor values |
| 634 | + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): |
| 635 | + return True |
630 | 636 |
|
631 |
| - return False |
| 637 | + return False |
| 638 | + except Exception: |
| 639 | + return False |
632 | 640 |
|
633 | 641 | @classmethod
|
634 | 642 | def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
@@ -716,8 +724,8 @@ def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
716 | 724 | if mod.path.is_file():
|
717 | 725 | return MatchCertainty.NEVER
|
718 | 726 |
|
719 |
| - for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}: |
720 |
| - if cls.file_looks_like_embedding(mod, mod.path / filename): |
| 727 | + for p in mod.path.iterdir(): |
| 728 | + if cls.file_looks_like_embedding(mod, p): |
721 | 729 | return MatchCertainty.MAYBE
|
722 | 730 |
|
723 | 731 | return MatchCertainty.NEVER
|
@@ -929,14 +937,39 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
|
929 | 937 | format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
930 | 938 |
|
931 | 939 |
|
932 |
| -class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase): |
| 940 | +class SpandrelImageToImageConfig(ModelConfigBase): |
933 | 941 | """Model config for Spandrel Image to Image models."""
|
934 | 942 |
|
935 | 943 | _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
|
936 | 944 |
|
937 | 945 | type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
938 | 946 | format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
939 | 947 |
|
| 948 | + @classmethod |
| 949 | + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: |
| 950 | + if not mod.path.is_file(): |
| 951 | + return MatchCertainty.NEVER |
| 952 | + |
| 953 | + try: |
| 954 | + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were |
| 955 | + # explored to avoid this: |
| 956 | + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta |
| 957 | + # device. Unfortunately, some Spandrel models perform operations during initialization that are not |
| 958 | + # supported on meta tensors. |
| 959 | + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. |
| 960 | + # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to |
| 961 | + # maintain it, and the risk of false positive detections is higher. |
| 962 | + SpandrelImageToImageModel.load_from_file(mod.path) |
| 963 | + return MatchCertainty.EXACT |
| 964 | + except spandrel.UnsupportedModelError: |
| 965 | + pass |
| 966 | + except Exception as e: |
| 967 | + logger.warning( |
| 968 | + f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}" |
| 969 | + ) |
| 970 | + |
| 971 | + return MatchCertainty.NEVER |
| 972 | + |
940 | 973 |
|
941 | 974 | class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
942 | 975 | """Model config for SigLIP."""
|
|
0 commit comments