Skip to content

Commit 3f3f941

Browse files
feat(mm): add UnknownModelConfig
1 parent 4d585e3 commit 3f3f941

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from enum import Enum
2929
from inspect import isabstract
3030
from pathlib import Path
31-
from typing import ClassVar, Literal, Optional, TypeAlias, Union
31+
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
3232

3333
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
3434
from typing_extensions import Annotated, Any, Dict
@@ -109,6 +109,18 @@ class MatchSpeed(int, Enum):
109109
SLOW = 2
110110

111111

112+
class LegacyProbeMixin:
113+
"""Mixin for classes using the legacy probe for model classification."""
114+
115+
@classmethod
116+
def matches(cls, *args, **kwargs):
117+
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
118+
119+
@classmethod
120+
def parse(cls, *args, **kwargs):
121+
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
122+
123+
112124
class ModelConfigBase(ABC, BaseModel):
113125
"""
114126
Abstract Base class for model configurations.
@@ -152,15 +164,15 @@ def json_schema_extra(schema: dict[str, Any]) -> None:
152164
)
153165
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
154166

155-
USING_LEGACY_PROBE: ClassVar[set] = set()
156-
USING_CLASSIFY_API: ClassVar[set] = set()
167+
USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
168+
USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
157169
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
158170

159171
def __init_subclass__(cls, **kwargs):
160172
super().__init_subclass__(**kwargs)
161173
if issubclass(cls, LegacyProbeMixin):
162174
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
163-
else:
175+
elif cls is not UnknownModelConfig:
164176
ModelConfigBase.USING_CLASSIFY_API.add(cls)
165177

166178
@staticmethod
@@ -170,7 +182,9 @@ def all_config_classes():
170182
return concrete
171183

172184
@staticmethod
173-
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
185+
def classify(
186+
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
187+
) -> "AnyModelConfig":
174188
"""
175189
Returns the best matching ModelConfig instance from a model's file/folder path.
176190
Raises InvalidModelConfigException if no valid configuration is found.
@@ -192,7 +206,10 @@ def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "bla
192206
else:
193207
return config_cls.from_model_on_disk(mod, **overrides)
194208

195-
raise InvalidModelConfigException("Unable to determine model type")
209+
try:
210+
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
211+
except Exception:
212+
raise InvalidModelConfigException("Unable to determine model type")
196213

197214
@classmethod
198215
def get_tag(cls) -> Tag:
@@ -256,16 +273,17 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
256273
return cls(**fields)
257274

258275

259-
class LegacyProbeMixin:
260-
"""Mixin for classes using the legacy probe for model classification."""
276+
class UnknownModelConfig(ModelConfigBase):
277+
type: Literal[ModelType.Unknown] = ModelType.Unknown
278+
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
261279

262280
@classmethod
263-
def matches(cls, *args, **kwargs):
264-
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
281+
def matches(cls, *args, **kwargs) -> bool:
282+
raise NotImplementedError("UnknownModelConfig cannot match anything")
265283

266284
@classmethod
267-
def parse(cls, *args, **kwargs):
268-
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
285+
def parse(cls, *args, **kwargs) -> dict[str, Any]:
286+
raise NotImplementedError("UnknownModelConfig cannot parse anything")
269287

270288

271289
class CheckpointConfigBase(ABC, BaseModel):
@@ -353,7 +371,7 @@ def matches(cls, mod: ModelOnDisk) -> bool:
353371

354372
metadata = mod.metadata()
355373
return (
356-
metadata.get("modelspec.sai_model_spec")
374+
bool(metadata.get("modelspec.sai_model_spec"))
357375
and metadata.get("ot_branch") == "omi_format"
358376
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
359377
)
@@ -751,6 +769,7 @@ def get_model_discriminator_value(v: Any) -> str:
751769
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
752770
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
753771
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
772+
Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()],
754773
],
755774
Discriminator(get_model_discriminator_value),
756775
]

invokeai/backend/model_manager/taxonomy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ModelType(str, Enum):
5555
FluxRedux = "flux_redux"
5656
LlavaOnevision = "llava_onevision"
5757
Video = "video"
58+
Unknown = "unknown"
5859

5960

6061
class SubModelType(str, Enum):
@@ -107,6 +108,7 @@ class ModelFormat(str, Enum):
107108
BnbQuantizednf4b = "bnb_quantized_nf4b"
108109
GGUFQuantized = "gguf_quantized"
109110
Api = "api"
111+
Unknown = "unknown"
110112

111113

112114
class SchedulerPredictionType(str, Enum):

0 commit comments

Comments
 (0)