Skip to content

Commit 03b3191

Browse files
docs(mm): add comments for identification utils
1 parent aea7e0f commit 03b3191

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

invokeai/backend/model_manager/configs/identification_utils.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ def __init__(self, reason: str):
2121

2222

2323
def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
24+
"""Load the diffusers/transformers model config file and return it as a dictionary. The config file is expected
25+
to be in JSON format.
26+
27+
Args:
28+
config_path: The path to the config file, or a set of paths to try.
29+
30+
Returns:
31+
The config file as a dictionary.
32+
33+
Raises:
34+
NotAMatch if the config file is missing or cannot be loaded.
35+
"""
2436
paths_to_check = config_path if isinstance(config_path, set) else {config_path}
2537

2638
problems: dict[Path, str] = {}
@@ -45,6 +57,12 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
4557
def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> str:
4658
"""Load the diffusers/transformers model config file and return the class name.
4759
60+
Args:
61+
config_path: The path to the config file, or a set of paths to try.
62+
63+
Returns:
64+
The class name from the config file.
65+
4866
Raises:
4967
NotAMatch if the config file is missing or does not contain a valid class name.
5068
"""
@@ -69,20 +87,22 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
6987
return config_class_name
7088

7189

72-
def raise_for_class_name(config_path: Path | set[Path], expected: set[str]) -> None:
90+
def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[str]) -> None:
7391
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
7492
7593
Args:
76-
config_path: The path to the config file.
77-
expected: The expected class names.
94+
config_path: The path to the config file, or a set of paths to try.
95+
class_name: The expected class name, or a set of expected class names.
7896
7997
Raises:
8098
NotAMatch if the class name is not in the expected set.
8199
"""
82100

83-
class_name = get_class_name_from_config_dict_or_raise(config_path)
84-
if class_name not in expected:
85-
raise NotAMatchError(f"invalid class name from config: {class_name}")
101+
class_name = {class_name} if isinstance(class_name, str) else class_name
102+
103+
actual_class_name = get_class_name_from_config_dict_or_raise(config_path)
104+
if actual_class_name not in class_name:
105+
raise NotAMatchError(f"invalid class name from config: {actual_class_name}")
86106

87107

88108
def raise_for_override_fields(candidate_config_class: type[BaseModel], override_fields: dict[str, Any]) -> None:
@@ -91,6 +111,9 @@ def raise_for_override_fields(candidate_config_class: type[BaseModel], override_
91111
For example, if the candidate config class has a field "base" of type Literal[BaseModelType.StableDiffusion1], and
92112
the override fields contain "base": BaseModelType.Flux, this function will raise NotAMatch.
93113
114+
Internally, this function extracts the pydantic schema for each individual override field from the candidate config
115+
class and validates the override value against that schema. Post-instantiation validators are not run.
116+
94117
Args:
95118
candidate_config_class: The config class that is being tested.
96119
override_fields: The override fields provided by the user.

0 commit comments

Comments
 (0)