11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import importlib
14
15
import inspect
15
16
import re
16
17
from contextlib import nullcontext
72
73
}
73
74
74
75
76
+ def _get_single_file_loadable_mapping_class (cls ):
77
+ diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
78
+ for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES :
79
+ loadable_class = getattr (diffusers_module , loadable_class_str )
80
+
81
+ if issubclass (cls , loadable_class ):
82
+ return loadable_class_str
83
+
84
+ return None
85
+
86
+
75
87
def _get_mapping_function_kwargs (mapping_fn , ** kwargs ):
76
88
parameters = inspect .signature (mapping_fn ).parameters
77
89
@@ -149,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
149
161
```
150
162
"""
151
163
152
- class_name = cls .__name__
153
- if class_name not in SINGLE_FILE_LOADABLE_CLASSES :
164
+ mapping_class_name = _get_single_file_loadable_mapping_class (cls )
165
+ # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
166
+ if mapping_class_name is None :
154
167
raise ValueError (
155
168
f"FromOriginalModelMixin is currently only compatible with { ', ' .join (SINGLE_FILE_LOADABLE_CLASSES .keys ())} "
156
169
)
@@ -195,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
195
208
revision = revision ,
196
209
)
197
210
198
- mapping_functions = SINGLE_FILE_LOADABLE_CLASSES [class_name ]
211
+ mapping_functions = SINGLE_FILE_LOADABLE_CLASSES [mapping_class_name ]
199
212
200
213
checkpoint_mapping_fn = mapping_functions ["checkpoint_mapping_fn" ]
201
214
if original_config :
@@ -207,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
207
220
if config_mapping_fn is None :
208
221
raise ValueError (
209
222
(
210
- f"`original_config` has been provided for { class_name } but no mapping function"
223
+ f"`original_config` has been provided for { mapping_class_name } but no mapping function"
211
224
"was found to convert the original config to a Diffusers config in"
212
225
"`diffusers.loaders.single_file_utils`"
213
226
)
@@ -267,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
267
280
)
268
281
if not diffusers_format_checkpoint :
269
282
raise SingleFileComponentError (
270
- f"Failed to load { class_name } . Weights for this component appear to be missing in the checkpoint."
283
+ f"Failed to load { mapping_class_name } . Weights for this component appear to be missing in the checkpoint."
271
284
)
272
285
273
286
ctx = init_empty_weights if is_accelerate_available () else nullcontext
0 commit comments