1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import importlib
1415import inspect
1516import re
1617from contextlib import nullcontext
7273}
7374
7475
76+ def _get_single_file_loadable_mapping_class (cls ):
77+ diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
78+ for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES :
79+ loadable_class = getattr (diffusers_module , loadable_class_str )
80+
81+ if issubclass (cls , loadable_class ):
82+ return loadable_class_str
83+
84+ return None
85+
86+
7587def _get_mapping_function_kwargs (mapping_fn , ** kwargs ):
7688 parameters = inspect .signature (mapping_fn ).parameters
7789
@@ -149,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
149161 ```
150162 """
151163
152- class_name = cls .__name__
153- if class_name not in SINGLE_FILE_LOADABLE_CLASSES :
164+ mapping_class_name = _get_single_file_loadable_mapping_class (cls )
165+ # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
166+ if mapping_class_name is None :
154167 raise ValueError (
155168 f"FromOriginalModelMixin is currently only compatible with { ', ' .join (SINGLE_FILE_LOADABLE_CLASSES .keys ())} "
156169 )
@@ -195,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
195208 revision = revision ,
196209 )
197210
198- mapping_functions = SINGLE_FILE_LOADABLE_CLASSES [class_name ]
211+ mapping_functions = SINGLE_FILE_LOADABLE_CLASSES [mapping_class_name ]
199212
200213 checkpoint_mapping_fn = mapping_functions ["checkpoint_mapping_fn" ]
201214 if original_config :
@@ -207,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
207220 if config_mapping_fn is None :
208221 raise ValueError (
209222 (
210- f"`original_config` has been provided for { class_name } but no mapping function"
223+ f"`original_config` has been provided for { mapping_class_name } but no mapping function"
211224 "was found to convert the original config to a Diffusers config in"
212225 "`diffusers.loaders.single_file_utils`"
213226 )
@@ -267,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
267280 )
268281 if not diffusers_format_checkpoint :
269282 raise SingleFileComponentError (
270- f"Failed to load { class_name } . Weights for this component appear to be missing in the checkpoint."
283+ f"Failed to load { mapping_class_name } . Weights for this component appear to be missing in the checkpoint."
271284 )
272285
273286 ctx = init_empty_weights if is_accelerate_available () else nullcontext
0 commit comments