Skip to content

Commit 20a0231

Browse files
fix(mm): vae class inheritance and config_path
1 parent 82ffb58 commit 20a0231

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,11 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
588588
}
589589

590590

591-
class VAEConfigBase(CheckpointConfigBase):
591+
class VAEConfigBase(ABC, BaseModel):
592592
type: Literal[ModelType.VAE] = ModelType.VAE
593593

594594

595-
class VAECheckpointConfig(VAEConfigBase, ModelConfigBase):
595+
class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
596596
"""Model config for standalone VAE models."""
597597

598598
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@@ -618,7 +618,20 @@ def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
618618
@classmethod
619619
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
620620
base = cls.get_base_type(mod)
621-
return {"base": base}
621+
config_path = (
622+
# For flux, this is a key in invokeai.backend.flux.util.ae_params
623+
# Due to model type and format being the descriminator for model configs this
624+
# is used rather than attempting to support flux with separate model types and format
625+
# If changed in the future, please fix me
626+
"flux"
627+
if base is BaseModelType.Flux
628+
else "stable-diffusion/v1-inference.yaml"
629+
if base is BaseModelType.StableDiffusion1
630+
else "stable-diffusion/sd_xl_base.yaml"
631+
if base is BaseModelType.StableDiffusionXL
632+
else "stable-diffusion/v2-inference.yaml"
633+
)
634+
return {"base": base, "config_path": config_path}
622635

623636
@classmethod
624637
def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
@@ -635,7 +648,7 @@ def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
635648
raise InvalidModelConfigException("Cannot determine base type")
636649

637650

638-
class VAEDiffusersConfig(VAEConfigBase, ModelConfigBase):
651+
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
639652
"""Model config for standalone VAE models (diffusers version)."""
640653

641654
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

0 commit comments

Comments
 (0)