Skip to content

Commit 1494487

Browse files
psychedeliciousmaryhipp
authored andcommitted
feat(mm): add model taxonomy for API models & Imagen3 as base model type
1 parent 07bcf3c commit 1494487

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,21 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
600600
}
601601

602602

603+
class ApiModelConfig(MainConfigBase, ModelConfigBase):
604+
"""Model config for API-based models."""
605+
606+
format: Literal[ModelFormat.Api] = ModelFormat.Api
607+
608+
@classmethod
609+
def matches(cls, mod: ModelOnDisk) -> bool:
610+
# API models are not stored on disk, so we can't match them.
611+
return False
612+
613+
@classmethod
614+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
615+
raise NotImplementedError("API models are not parsed from disk.")
616+
617+
603618
def get_model_discriminator_value(v: Any) -> str:
604619
"""
605620
Computes the discriminator value for a model config.
@@ -667,6 +682,7 @@ def get_model_discriminator_value(v: Any) -> str:
667682
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
668683
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
669684
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
685+
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
670686
],
671687
Discriminator(get_model_discriminator_value),
672688
]

invokeai/backend/model_manager/taxonomy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class BaseModelType(str, Enum):
2626
StableDiffusionXLRefiner = "sdxl-refiner"
2727
Flux = "flux"
2828
CogView4 = "cogview4"
29-
# Kandinsky2_1 = "kandinsky-2.1"
29+
Imagen3 = "imagen3"
3030

3131

3232
class ModelType(str, Enum):
@@ -98,6 +98,7 @@ class ModelFormat(str, Enum):
9898
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
9999
BnbQuantizednf4b = "bnb_quantized_nf4b"
100100
GGUFQuantized = "gguf_quantized"
101+
Api = "api"
101102

102103

103104
class SchedulerPredictionType(str, Enum):

0 commit comments

Comments
 (0)