@@ -54,7 +54,7 @@ def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
54
54
raise NotAMatchError (f"unable to load config file(s): { problems } " )
55
55
56
56
57
- def get_class_name_from_config_dict_or_raise (config_path : Path | set [Path ]) -> str :
57
+ def get_class_name_from_config_dict_or_raise (config : Path | set [Path ] | dict [ str , Any ]) -> str :
58
58
"""Load the diffusers/transformers model config file and return the class name.
59
59
60
60
Args:
@@ -67,7 +67,8 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
67
67
NotAMatch if the config file is missing or does not contain a valid class name.
68
68
"""
69
69
70
- config = get_config_dict_or_raise (config_path )
70
+ if not isinstance (config , dict ):
71
+ config = get_config_dict_or_raise (config )
71
72
72
73
try :
73
74
if "_class_name" in config :
@@ -79,15 +80,15 @@ def get_class_name_from_config_dict_or_raise(config_path: Path | set[Path]) -> s
79
80
else :
80
81
raise ValueError ("missing _class_name or architectures field" )
81
82
except Exception as e :
82
- raise NotAMatchError (f"unable to determine class name from config file: { config_path } " ) from e
83
+ raise NotAMatchError (f"unable to determine class name from config file: { config } " ) from e
83
84
84
85
if not isinstance (config_class_name , str ):
85
86
raise NotAMatchError (f"_class_name or architectures field is not a string: { config_class_name } " )
86
87
87
88
return config_class_name
88
89
89
90
90
- def raise_for_class_name (config_path : Path | set [Path ], class_name : str | set [str ]) -> None :
91
+ def raise_for_class_name (config : Path | set [Path ] | dict [ str , Any ], class_name : str | set [str ]) -> None :
91
92
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
92
93
93
94
Args:
@@ -100,7 +101,7 @@ def raise_for_class_name(config_path: Path | set[Path], class_name: str | set[st
100
101
101
102
class_name = {class_name } if isinstance (class_name , str ) else class_name
102
103
103
- actual_class_name = get_class_name_from_config_dict_or_raise (config_path )
104
+ actual_class_name = get_class_name_from_config_dict_or_raise (config )
104
105
if actual_class_name not in class_name :
105
106
raise NotAMatchError (f"invalid class name from config: { actual_class_name } " )
106
107
0 commit comments