28
28
from enum import Enum
29
29
from inspect import isabstract
30
30
from pathlib import Path
31
- from typing import ClassVar , Literal , Optional , TypeAlias , Union
31
+ from typing import ClassVar , Literal , Optional , Type , TypeAlias , Union
32
32
33
33
from pydantic import BaseModel , ConfigDict , Discriminator , Field , Tag , TypeAdapter
34
34
from typing_extensions import Annotated , Any , Dict
@@ -109,6 +109,18 @@ class MatchSpeed(int, Enum):
109
109
SLOW = 2
110
110
111
111
112
+ class LegacyProbeMixin :
113
+ """Mixin for classes using the legacy probe for model classification."""
114
+
115
+ @classmethod
116
+ def matches (cls , * args , ** kwargs ):
117
+ raise NotImplementedError (f"Method 'matches' not implemented for { cls .__name__ } " )
118
+
119
+ @classmethod
120
+ def parse (cls , * args , ** kwargs ):
121
+ raise NotImplementedError (f"Method 'parse' not implemented for { cls .__name__ } " )
122
+
123
+
112
124
class ModelConfigBase (ABC , BaseModel ):
113
125
"""
114
126
Abstract Base class for model configurations.
@@ -152,15 +164,15 @@ def json_schema_extra(schema: dict[str, Any]) -> None:
152
164
)
153
165
usage_info : Optional [str ] = Field (default = None , description = "Usage information for this model" )
154
166
155
- USING_LEGACY_PROBE : ClassVar [set ] = set ()
156
- USING_CLASSIFY_API : ClassVar [set ] = set ()
167
+ USING_LEGACY_PROBE : ClassVar [set [ Type [ "ModelConfigBase" ]] ] = set ()
168
+ USING_CLASSIFY_API : ClassVar [set [ Type [ "ModelConfigBase" ]] ] = set ()
157
169
_MATCH_SPEED : ClassVar [MatchSpeed ] = MatchSpeed .MED
158
170
159
171
def __init_subclass__ (cls , ** kwargs ):
160
172
super ().__init_subclass__ (** kwargs )
161
173
if issubclass (cls , LegacyProbeMixin ):
162
174
ModelConfigBase .USING_LEGACY_PROBE .add (cls )
163
- else :
175
+ elif cls is not UnknownModelConfig :
164
176
ModelConfigBase .USING_CLASSIFY_API .add (cls )
165
177
166
178
@staticmethod
@@ -170,7 +182,9 @@ def all_config_classes():
170
182
return concrete
171
183
172
184
@staticmethod
173
- def classify (mod : str | Path | ModelOnDisk , hash_algo : HASHING_ALGORITHMS = "blake3_single" , ** overrides ):
185
+ def classify (
186
+ mod : str | Path | ModelOnDisk , hash_algo : HASHING_ALGORITHMS = "blake3_single" , ** overrides
187
+ ) -> "AnyModelConfig" :
174
188
"""
175
189
Returns the best matching ModelConfig instance from a model's file/folder path.
176
190
Raises InvalidModelConfigException if no valid configuration is found.
@@ -192,7 +206,10 @@ def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "bla
192
206
else :
193
207
return config_cls .from_model_on_disk (mod , ** overrides )
194
208
195
- raise InvalidModelConfigException ("Unable to determine model type" )
209
+ try :
210
+ return UnknownModelConfig .from_model_on_disk (mod , ** overrides )
211
+ except Exception :
212
+ raise InvalidModelConfigException ("Unable to determine model type" )
196
213
197
214
@classmethod
198
215
def get_tag (cls ) -> Tag :
@@ -256,16 +273,17 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
256
273
return cls (** fields )
257
274
258
275
259
- class LegacyProbeMixin :
260
- """Mixin for classes using the legacy probe for model classification."""
276
+ class UnknownModelConfig (ModelConfigBase ):
277
+ type : Literal [ModelType .Unknown ] = ModelType .Unknown
278
+ format : Literal [ModelFormat .Unknown ] = ModelFormat .Unknown
261
279
262
280
@classmethod
263
- def matches (cls , * args , ** kwargs ):
264
- raise NotImplementedError (f"Method 'matches' not implemented for { cls . __name__ } " )
281
+ def matches (cls , * args , ** kwargs ) -> bool :
282
+ raise NotImplementedError ("UnknownModelConfig cannot match anything " )
265
283
266
284
@classmethod
267
- def parse (cls , * args , ** kwargs ):
268
- raise NotImplementedError (f"Method ' parse' not implemented for { cls . __name__ } " )
285
+ def parse (cls , * args , ** kwargs ) -> dict [ str , Any ] :
286
+ raise NotImplementedError ("UnknownModelConfig cannot parse anything " )
269
287
270
288
271
289
class CheckpointConfigBase (ABC , BaseModel ):
@@ -353,7 +371,7 @@ def matches(cls, mod: ModelOnDisk) -> bool:
353
371
354
372
metadata = mod .metadata ()
355
373
return (
356
- metadata .get ("modelspec.sai_model_spec" )
374
+ bool ( metadata .get ("modelspec.sai_model_spec" ) )
357
375
and metadata .get ("ot_branch" ) == "omi_format"
358
376
and metadata ["modelspec.architecture" ].split ("/" )[1 ].lower () == "lora"
359
377
)
@@ -751,6 +769,7 @@ def get_model_discriminator_value(v: Any) -> str:
751
769
Annotated [LlavaOnevisionConfig , LlavaOnevisionConfig .get_tag ()],
752
770
Annotated [ApiModelConfig , ApiModelConfig .get_tag ()],
753
771
Annotated [VideoApiModelConfig , VideoApiModelConfig .get_tag ()],
772
+ Annotated [UnknownModelConfig , UnknownModelConfig .get_tag ()],
754
773
],
755
774
Discriminator (get_model_discriminator_value ),
756
775
]
0 commit comments