Skip to content

Commit 44f24a4

Browse files
committed
Flax
1 parent b8fa81a commit 44f24a4

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ def load_module(name, value):
846846
continue
847847
class_name = passed_class_obj[key].__class__.__name__
848848
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
849+
expected_class_name = expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name
849850
if key == "scheduler" and class_name not in scheduler_types:
850851
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
851852
elif key != "scheduler" and class_name != expected_class_name:

0 commit comments

Comments
 (0)