Skip to content

Commit a216b0b

Browse files
luocfprimeDN6sayakpaul
authored
fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (#8454)
* fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (cherry picked from commit 9285997) * Update src/diffusers/loaders/single_file_model.py Co-authored-by: Dhruv Nair <[email protected]> * Update single_file_model.py * Update single_file_model.py --------- Co-authored-by: Dhruv Nair <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 150142c commit a216b0b

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import importlib
1415
import inspect
1516
import re
1617
from contextlib import nullcontext
@@ -72,6 +73,17 @@
7273
}
7374

7475

76+
def _get_single_file_loadable_mapping_class(cls):
77+
diffusers_module = importlib.import_module(__name__.split(".")[0])
78+
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
79+
loadable_class = getattr(diffusers_module, loadable_class_str)
80+
81+
if issubclass(cls, loadable_class):
82+
return loadable_class_str
83+
84+
return None
85+
86+
7587
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
7688
parameters = inspect.signature(mapping_fn).parameters
7789

@@ -149,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
149161
```
150162
"""
151163

152-
class_name = cls.__name__
153-
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
164+
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
165+
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
166+
if mapping_class_name is None:
154167
raise ValueError(
155168
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
156169
)
@@ -195,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
195208
revision=revision,
196209
)
197210

198-
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
211+
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
199212

200213
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
201214
if original_config:
@@ -207,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
207220
if config_mapping_fn is None:
208221
raise ValueError(
209222
(
210-
f"`original_config` has been provided for {class_name} but no mapping function"
223+
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
211224
"was found to convert the original config to a Diffusers config in"
212225
"`diffusers.loaders.single_file_utils`"
213226
)
@@ -267,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
267280
)
268281
if not diffusers_format_checkpoint:
269282
raise SingleFileComponentError(
270-
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
283+
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
271284
)
272285

273286
ctx = init_empty_weights if is_accelerate_available() else nullcontext

0 commit comments

Comments
 (0)