|
30 | 30 | from pathlib import Path
|
31 | 31 | from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
|
32 | 32 |
|
| 33 | +import torch |
33 | 34 | from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
34 | 35 | from typing_extensions import Annotated, Any, Dict
|
35 | 36 |
|
@@ -601,19 +602,137 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
601 | 602 | type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
602 | 603 |
|
603 | 604 |
|
604 |
| -class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase): |
| 605 | +class TextualInversionConfigBase(ABC, BaseModel): |
| 606 | + type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion |
| 607 | + |
| 608 | + @classmethod |
| 609 | + def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: |
| 610 | + p = path or mod.path |
| 611 | + |
| 612 | + if not p.exists(): |
| 613 | + return False |
| 614 | + |
| 615 | + if p.is_dir(): |
| 616 | + return False |
| 617 | + |
| 618 | + if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}: |
| 619 | + return True |
| 620 | + |
| 621 | + state_dict = mod.load_state_dict(p) |
| 622 | + |
| 623 | + # Heuristic: textual inversion embeddings have these keys |
| 624 | + if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()): |
| 625 | + return True |
| 626 | + |
| 627 | + # Heuristic: small state dict with all tensor values |
| 628 | + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): |
| 629 | + return True |
| 630 | + |
| 631 | + return False |
| 632 | + |
| 633 | + @classmethod |
| 634 | + def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: |
| 635 | + p = path or mod.path |
| 636 | + |
| 637 | + try: |
| 638 | + state_dict = mod.load_state_dict(p) |
| 639 | + if "string_to_token" in state_dict: |
| 640 | + token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] |
| 641 | + elif "emb_params" in state_dict: |
| 642 | + token_dim = state_dict["emb_params"].shape[-1] |
| 643 | + elif "clip_g" in state_dict: |
| 644 | + token_dim = state_dict["clip_g"].shape[-1] |
| 645 | + else: |
| 646 | + token_dim = list(state_dict.values())[0].shape[0] |
| 647 | + |
| 648 | + match token_dim: |
| 649 | + case 768: |
| 650 | + return BaseModelType.StableDiffusion1 |
| 651 | + case 1024: |
| 652 | + return BaseModelType.StableDiffusion2 |
| 653 | + case 1280: |
| 654 | + return BaseModelType.StableDiffusionXL |
| 655 | + case _: |
| 656 | + pass |
| 657 | + except Exception: |
| 658 | + pass |
| 659 | + |
| 660 | + raise InvalidModelConfigException(f"{p}: Could not determine base type") |
| 661 | + |
| 662 | + |
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for textual inversion embeddings stored as a single file."""

    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @classmethod
    def get_tag(cls) -> Tag:
        """Discriminator tag for the pydantic tagged union of model configs."""
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Return how certain we are that *mod* is a single-file TI embedding."""
        is_embedding_override = overrides.get("type") is ModelType.TextualInversion
        is_file_override = overrides.get("format") is ModelFormat.EmbeddingFile

        # Explicit user overrides short-circuit the heuristics.
        if is_embedding_override and is_file_override:
            return MatchCertainty.OVERRIDE

        # A file-format config can never match a directory.
        if mod.path.is_dir():
            return MatchCertainty.NEVER

        if cls.file_looks_like_embedding(mod):
            return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Extract config fields probed from disk (currently just the base type).

        Raises:
            InvalidModelConfigException: if the base type cannot be determined.
        """
        # get_base already raises InvalidModelConfigException with an equivalent
        # message on failure; the previous try/except-pass/re-raise pattern only
        # discarded the original exception context.
        return {"base": cls.get_base(mod)}
| 697 | + |
| 698 | + |
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for textual inversion embeddings stored in a diffusers-style folder."""

    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @classmethod
    def get_tag(cls) -> Tag:
        """Discriminator tag for the pydantic tagged union of model configs."""
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Return how certain we are that *mod* is a folder-style TI embedding."""
        is_embedding_override = overrides.get("type") is ModelType.TextualInversion
        is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder

        # Explicit user overrides short-circuit the heuristics.
        if is_embedding_override and is_folder_override:
            return MatchCertainty.OVERRIDE

        # A folder-format config can never match a plain file.
        if mod.path.is_file():
            return MatchCertainty.NEVER

        # Tuple (not set) so candidates are checked in a deterministic order.
        candidates = ("learned_embeds.bin", "learned_embeds.safetensors")
        if any(cls.file_looks_like_embedding(mod, mod.path / name) for name in candidates):
            return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Extract config fields probed from disk (currently just the base type).

        Raises:
            InvalidModelConfigException: if no candidate file yields a base type.
        """
        # Probe each candidate file independently: a failure on the first must
        # not prevent trying the second. (The old code wrapped the whole loop in
        # one try and returned unconditionally on the first iteration, so the
        # second filename was never checked.)
        for filename in ("learned_embeds.bin", "learned_embeds.safetensors"):
            try:
                return {"base": cls.get_base(mod, mod.path / filename)}
            except Exception:
                continue

        raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
| 735 | + |
617 | 736 |
|
618 | 737 | class MainConfigBase(ABC, BaseModel):
|
619 | 738 | type: Literal[ModelType.Main] = ModelType.Main
|
|
0 commit comments