Skip to content

Commit 08b30ff

Browse files
feat(mm): more flexible config matching utils
1 parent a369114 commit 08b30ff

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

invokeai/backend/model_manager/configs/identification_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
5454
raise NotAMatchError(f"unable to load config file(s): {problems}")
5555

5656

57-
def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
57+
def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str:
5858
"""Load the diffusers/transformers model config file and return the class name.
5959
6060
Args:
@@ -67,7 +67,8 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
6767
NotAMatch if the config file is missing or does not contain a valid class name.
6868
"""
6969

70-
config = get_config_dict_or_raise(config_path)
70+
if not isinstance(config, dict):
71+
config = get_config_dict_or_raise(config)
7172

7273
try:
7374
if "_class_name" in config:
@@ -79,15 +80,15 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
7980
else:
8081
raise ValueError("missing _class_name or architectures field")
8182
except Exception as e:
82-
raise NotAMatchError(f"unable to determine class name from config file: {config_path}") from e
83+
raise NotAMatchError(f"unable to determine class name from config file: {config}") from e
8384

8485
if not isinstance(config_class_name, str):
8586
raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}")
8687

8788
return config_class_name
8889

8990

90-
def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[str]) -> None:
91+
def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None:
9192
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
9293
9394
Args:
@@ -100,7 +101,7 @@ def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[st
100101

101102
class_name = {class_name} if isinstance(class_name, str) else class_name
102103

103-
actual_class_name = get_class_name_from_config_dict_or_raise(config_path)
104+
actual_class_name = get_class_name_from_config_dict_or_raise(config)
104105
if actual_class_name not in class_name:
105106
raise NotAMatchError(f"invalid class name from config: {actual_class_name}")
106107

0 commit comments

Comments
 (0)