|
23 | 23 | # pyright: reportIncompatibleVariableOverride=false
|
24 | 24 | import json
|
25 | 25 | import logging
|
| 26 | +import re |
26 | 27 | import time
|
27 | 28 | from abc import ABC, abstractmethod
|
28 | 29 | from enum import Enum
|
@@ -73,6 +74,15 @@ class InvalidModelConfigException(Exception):
|
73 | 74 | DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
74 | 75 |
|
75 | 76 |
|
| 77 | +def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]: |
| 78 | + if "_class_name" in config: |
| 79 | + return config["_class_name"] |
| 80 | + elif "architectures" in config: |
| 81 | + return config["architectures"][0] |
| 82 | + else: |
| 83 | + return None |
| 84 | + |
| 85 | + |
76 | 86 | class SubmodelDefinition(BaseModel):
|
77 | 87 | path_or_prefix: str
|
78 | 88 | model_type: ModelType
|
@@ -578,18 +588,122 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
578 | 588 | }
|
579 | 589 |
|
580 | 590 |
|
581 |
| -class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase): |
| 591 | +class VAEConfigBase(CheckpointConfigBase): |
| 592 | + type: Literal[ModelType.VAE] = ModelType.VAE |
| 593 | + |
| 594 | + |
| 595 | +class VAECheckpointConfig(VAEConfigBase, ModelConfigBase): |
582 | 596 | """Model config for standalone VAE models."""
|
583 | 597 |
|
584 |
| - type: Literal[ModelType.VAE] = ModelType.VAE |
| 598 | + format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint |
| 599 | + |
| 600 | + KEY_PREFIXES: ClassVar = {"encoder.conv_in", "decoder.conv_in"} |
| 601 | + |
| 602 | + @classmethod |
| 603 | + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: |
| 604 | + is_vae_override = overrides.get("type") is ModelType.VAE |
| 605 | + is_checkpoint_override = overrides.get("format") is ModelFormat.Checkpoint |
| 606 | + |
| 607 | + if is_vae_override and is_checkpoint_override: |
| 608 | + return MatchCertainty.OVERRIDE |
| 609 | + |
| 610 | + if mod.path.is_dir(): |
| 611 | + return MatchCertainty.NEVER |
| 612 | + |
| 613 | + if mod.has_keys_starting_with(cls.KEY_PREFIXES): |
| 614 | + return MatchCertainty.MAYBE |
| 615 | + |
| 616 | + return MatchCertainty.NEVER |
| 617 | + |
| 618 | + @classmethod |
| 619 | + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: |
| 620 | + base = cls.get_base_type(mod) |
| 621 | + return {"base": base} |
| 622 | + |
| 623 | + @classmethod |
| 624 | + def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType: |
| 625 | + # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name |
| 626 | + for regexp, basetype in [ |
| 627 | + (r"xl", BaseModelType.StableDiffusionXL), |
| 628 | + (r"sd2", BaseModelType.StableDiffusion2), |
| 629 | + (r"vae", BaseModelType.StableDiffusion1), |
| 630 | + (r"FLUX.1-schnell_ae", BaseModelType.Flux), |
| 631 | + ]: |
| 632 | + if re.search(regexp, mod.path.name, re.IGNORECASE): |
| 633 | + return basetype |
| 634 | + |
| 635 | + raise InvalidModelConfigException("Cannot determine base type") |
585 | 636 |
|
586 | 637 |
|
587 |
| -class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase): |
| 638 | +class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase): |
588 | 639 | """Model config for standalone VAE models (diffusers version)."""
|
589 | 640 |
|
590 |
| - type: Literal[ModelType.VAE] = ModelType.VAE |
591 | 641 | format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
592 | 642 |
|
| 643 | + CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"} |
| 644 | + |
| 645 | + @classmethod |
| 646 | + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: |
| 647 | + is_vae_override = overrides.get("type") is ModelType.VAE |
| 648 | + is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers |
| 649 | + |
| 650 | + if is_vae_override and is_diffusers_override: |
| 651 | + return MatchCertainty.OVERRIDE |
| 652 | + |
| 653 | + if mod.path.is_file(): |
| 654 | + return MatchCertainty.NEVER |
| 655 | + |
| 656 | + try: |
| 657 | + config = cls.get_config(mod) |
| 658 | + class_name = get_class_name_from_config(config) |
| 659 | + if class_name in cls.CLASS_NAMES: |
| 660 | + return MatchCertainty.EXACT |
| 661 | + except Exception: |
| 662 | + pass |
| 663 | + |
| 664 | + return MatchCertainty.NEVER |
| 665 | + |
| 666 | + @classmethod |
| 667 | + def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]: |
| 668 | + config_path = mod.path / "config.json" |
| 669 | + with open(config_path, "r") as file: |
| 670 | + return json.load(file) |
| 671 | + |
| 672 | + @classmethod |
| 673 | + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: |
| 674 | + base = cls.get_base_type(mod) |
| 675 | + return {"base": base} |
| 676 | + |
| 677 | + @classmethod |
| 678 | + def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType: |
| 679 | + if cls._config_looks_like_sdxl(mod): |
| 680 | + return BaseModelType.StableDiffusionXL |
| 681 | + elif cls._name_looks_like_sdxl(mod): |
| 682 | + return BaseModelType.StableDiffusionXL |
| 683 | + else: |
| 684 | + # We do not support diffusers VAEs for any other base model at this time... YOLO |
| 685 | + return BaseModelType.StableDiffusion1 |
| 686 | + |
| 687 | + @classmethod |
| 688 | + def _config_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool: |
| 689 | + config = cls.get_config(mod) |
| 690 | + # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. |
| 691 | + return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] |
| 692 | + |
| 693 | + @classmethod |
| 694 | + def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool: |
| 695 | + # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down |
| 696 | + # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best |
| 697 | + # we can do is guess based on name. |
| 698 | + return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE)) |
| 699 | + |
| 700 | + @classmethod |
| 701 | + def _guess_name(cls, mod: ModelOnDisk) -> str: |
| 702 | + name = mod.path.name |
| 703 | + if name == "vae": |
| 704 | + name = mod.path.parent.name |
| 705 | + return name |
| 706 | + |
593 | 707 |
|
594 | 708 | class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
595 | 709 | """Model config for ControlNet models (diffusers version)."""
|
|
0 commit comments