@@ -21,6 +21,18 @@ def __init__(self, reason: str):
21
21
22
22
23
23
def get_config_dict_or_raise (config_path : Path | set [Path ]) -> dict [str , Any ]:
24
+ """Load the diffusers/transformers model config file and return it as a dictionary. The config file is expected
25
+ to be in JSON format.
26
+
27
+ Args:
28
+ config_path: The path to the config file, or a set of paths to try.
29
+
30
+ Returns:
31
+ The config file as a dictionary.
32
+
33
+ Raises:
34
+ NotAMatch if the config file is missing or cannot be loaded.
35
+ """
24
36
paths_to_check = config_path if isinstance (config_path , set ) else {config_path }
25
37
26
38
problems : dict [Path , str ] = {}
@@ -45,6 +57,12 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
45
57
def get_class_name_from_config_dict_or_raise (config_path : Path | set [Path ]) -> str :
46
58
"""Load the diffusers/transformers model config file and return the class name.
47
59
60
+ Args:
61
+ config_path: The path to the config file, or a set of paths to try.
62
+
63
+ Returns:
64
+ The class name from the config file.
65
+
48
66
Raises:
49
67
NotAMatch if the config file is missing or does not contain a valid class name.
50
68
"""
@@ -69,20 +87,22 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
69
87
return config_class_name
70
88
71
89
72
- def raise_for_class_name (config_path : Path | set [Path ], expected : set [str ]) -> None :
90
+ def raise_for_class_name (config_path : Path | set [Path ], class_name : str | set [str ]) -> None :
73
91
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
74
92
75
93
Args:
76
- config_path: The path to the config file.
77
- expected : The expected class names.
94
+ config_path: The path to the config file, or a set of paths to try .
95
+ class_name : The expected class name, or a set of expected class names.
78
96
79
97
Raises:
80
98
NotAMatch if the class name is not in the expected set.
81
99
"""
82
100
83
- class_name = get_class_name_from_config_dict_or_raise (config_path )
84
- if class_name not in expected :
85
- raise NotAMatchError (f"invalid class name from config: { class_name } " )
101
+ class_name = {class_name } if isinstance (class_name , str ) else class_name
102
+
103
+ actual_class_name = get_class_name_from_config_dict_or_raise (config_path )
104
+ if actual_class_name not in class_name :
105
+ raise NotAMatchError (f"invalid class name from config: { actual_class_name } " )
86
106
87
107
88
108
def raise_for_override_fields (candidate_config_class : type [BaseModel ], override_fields : dict [str , Any ]) -> None :
@@ -91,6 +111,9 @@ def raise_for_override_fields(candidate_config_class: type[BaseModel], override_
91
111
For example, if the candidate config class has a field "base" of type Literal[BaseModelType.StableDiffusion1], and
92
112
the override fields contain "base": BaseModelType.Flux, this function will raise NotAMatch.
93
113
114
+ Internally, this function extracts the pydantic schema for each individual override field from the candidate config
115
+ class and validates the override value against that schema. Post-instantiation validators are not run.
116
+
94
117
Args:
95
118
candidate_config_class: The config class that is being tested.
96
119
override_fields: The override fields provided by the user.
0 commit comments