@@ -588,11 +588,11 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
588
588
}
589
589
590
590
591
- class VAEConfigBase (CheckpointConfigBase ):
591
+ class VAEConfigBase (ABC , BaseModel ):
592
592
type : Literal [ModelType .VAE ] = ModelType .VAE
593
593
594
594
595
- class VAECheckpointConfig (VAEConfigBase , ModelConfigBase ):
595
+ class VAECheckpointConfig (VAEConfigBase , CheckpointConfigBase , ModelConfigBase ):
596
596
"""Model config for standalone VAE models."""
597
597
598
598
format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
@@ -618,7 +618,20 @@ def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
618
618
@classmethod
619
619
def parse (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
620
620
base = cls .get_base_type (mod )
621
- return {"base" : base }
621
+ config_path = (
622
+ # For flux, this is a key in invokeai.backend.flux.util.ae_params
623
+ # Due to model type and format being the descriminator for model configs this
624
+ # is used rather than attempting to support flux with separate model types and format
625
+ # If changed in the future, please fix me
626
+ "flux"
627
+ if base is BaseModelType .Flux
628
+ else "stable-diffusion/v1-inference.yaml"
629
+ if base is BaseModelType .StableDiffusion1
630
+ else "stable-diffusion/sd_xl_base.yaml"
631
+ if base is BaseModelType .StableDiffusionXL
632
+ else "stable-diffusion/v2-inference.yaml"
633
+ )
634
+ return {"base" : base , "config_path" : config_path }
622
635
623
636
@classmethod
624
637
def get_base_type (cls , mod : ModelOnDisk ) -> BaseModelType :
@@ -635,7 +648,7 @@ def get_base_type(cls, mod: ModelOnDisk) -> BaseModelType:
635
648
raise InvalidModelConfigException ("Cannot determine base type" )
636
649
637
650
638
- class VAEDiffusersConfig (VAEConfigBase , ModelConfigBase ):
651
+ class VAEDiffusersConfig (VAEConfigBase , DiffusersConfigBase , ModelConfigBase ):
639
652
"""Model config for standalone VAE models (diffusers version)."""
640
653
641
654
format : Literal [ModelFormat .Diffusers ] = ModelFormat .Diffusers
0 commit comments