Skip to content

Commit d81a3d6

Browse files
fix(mm): clip vision identification
1 parent 08b30ff commit d81a3d6

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

invokeai/backend/model_manager/configs/clip_vision.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
1010
from invokeai.backend.model_manager.configs.identification_utils import (
11-
common_config_paths,
12-
raise_for_class_name,
11+
NotAMatchError,
12+
get_class_name_from_config_dict_or_raise,
13+
get_config_dict_or_raise,
1314
raise_for_override_fields,
1415
raise_if_not_dir,
1516
)
@@ -34,11 +35,23 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
3435

3536
raise_for_override_fields(cls, override_fields)
3637

37-
raise_for_class_name(
38-
common_config_paths(mod.path),
39-
{
40-
"CLIPVisionModelWithProjection",
41-
},
42-
)
38+
cls.raise_if_config_doesnt_look_like_clip_vision(mod)
4339

4440
return cls(**override_fields)
41+
42+
@classmethod
43+
def raise_if_config_doesnt_look_like_clip_vision(cls, mod: ModelOnDisk) -> None:
44+
config_dict = get_config_dict_or_raise(mod.path / "config.json")
45+
class_name = get_class_name_from_config_dict_or_raise(config_dict)
46+
47+
if class_name == "CLIPVisionModelWithProjection":
48+
looks_like_clip_vision = True
49+
elif class_name == "CLIPModel" and "vision_config" in config_dict:
50+
looks_like_clip_vision = True
51+
else:
52+
looks_like_clip_vision = False
53+
54+
if not looks_like_clip_vision:
55+
raise NotAMatchError(
56+
f"config class name is {class_name}, not CLIPVisionModelWithProjection or CLIPModel with vision_config"
57+
)

0 commit comments

Comments
 (0)