Skip to content

Commit 8ae9716

Browse files
feat(mm): port TIs to new API
1 parent 7c72824 commit 8ae9716

File tree

1 file changed

+123
-4
lines changed

1 file changed

+123
-4
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 123 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pathlib import Path
3131
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
3232

33+
import torch
3334
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
3435
from typing_extensions import Annotated, Any, Dict
3536

@@ -601,19 +602,137 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
601602
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
602603

603604

604-
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
605+
class TextualInversionConfigBase(ABC, BaseModel):
606+
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
607+
608+
@classmethod
609+
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
610+
p = path or mod.path
611+
612+
if not p.exists():
613+
return False
614+
615+
if p.is_dir():
616+
return False
617+
618+
if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
619+
return True
620+
621+
state_dict = mod.load_state_dict(p)
622+
623+
# Heuristic: textual inversion embeddings have these keys
624+
if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()):
625+
return True
626+
627+
# Heuristic: small state dict with all tensor values
628+
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
629+
return True
630+
631+
return False
632+
633+
@classmethod
634+
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
635+
p = path or mod.path
636+
637+
try:
638+
state_dict = mod.load_state_dict(p)
639+
if "string_to_token" in state_dict:
640+
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
641+
elif "emb_params" in state_dict:
642+
token_dim = state_dict["emb_params"].shape[-1]
643+
elif "clip_g" in state_dict:
644+
token_dim = state_dict["clip_g"].shape[-1]
645+
else:
646+
token_dim = list(state_dict.values())[0].shape[0]
647+
648+
match token_dim:
649+
case 768:
650+
return BaseModelType.StableDiffusion1
651+
case 1024:
652+
return BaseModelType.StableDiffusion2
653+
case 1280:
654+
return BaseModelType.StableDiffusionXL
655+
case _:
656+
pass
657+
except Exception:
658+
pass
659+
660+
raise InvalidModelConfigException(f"{p}: Could not determine base type")
661+
662+
663+
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for single-file textual inversion embeddings."""

    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @classmethod
    def get_tag(cls) -> Tag:
        """Discriminator tag used by the pydantic union of model configs."""
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how likely `mod` is a single-file TI embedding."""
        is_embedding_override = overrides.get("type") is ModelType.TextualInversion
        is_file_override = overrides.get("format") is ModelFormat.EmbeddingFile

        # Explicit user overrides short-circuit the heuristics.
        if is_embedding_override and is_file_override:
            return MatchCertainty.OVERRIDE

        # A file-format config can never match a directory.
        if mod.path.is_dir():
            return MatchCertainty.NEVER

        if cls.file_looks_like_embedding(mod):
            return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Extract config fields readable from disk (currently just the base type).

        Raises:
            InvalidModelConfigException: if the base type cannot be determined.
        """
        try:
            return {"base": cls.get_base(mod)}
        except Exception as e:
            # Chain the cause so the original failure isn't silently discarded.
            raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") from e
697+
698+
699+
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for diffusers-folder textual inversion embeddings."""

    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @classmethod
    def get_tag(cls) -> Tag:
        """Discriminator tag used by the pydantic union of model configs."""
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Score how likely `mod` is a diffusers-folder TI embedding."""
        is_embedding_override = overrides.get("type") is ModelType.TextualInversion
        is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder

        # Explicit user overrides short-circuit the heuristics.
        if is_embedding_override and is_folder_override:
            return MatchCertainty.OVERRIDE

        # A folder-format config can never match a plain file.
        if mod.path.is_file():
            return MatchCertainty.NEVER

        # A tuple keeps probe order deterministic (a set does not).
        for filename in ("learned_embeds.safetensors", "learned_embeds.bin"):
            if cls.file_looks_like_embedding(mod, mod.path / filename):
                return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Extract config fields readable from disk (currently just the base type).

        Raises:
            InvalidModelConfigException: if no candidate embedding file yields a base type.
        """
        # The try must be per-file: wrapping the whole loop means a failure on
        # the first candidate aborts before the second filename is ever tried.
        for filename in ("learned_embeds.safetensors", "learned_embeds.bin"):
            try:
                return {"base": cls.get_base(mod, mod.path / filename)}
            except Exception:
                # This candidate failed; try the next known filename.
                continue

        raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
735+
617736

618737
class MainConfigBase(ABC, BaseModel):
619738
type: Literal[ModelType.Main] = ModelType.Main

0 commit comments

Comments
 (0)