Skip to content

Commit 185a78f

Browse files
committed
Flax, skip scheduler
1 parent 78c6e68 commit 185a78f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,9 +834,10 @@ def load_module(name, value):
834834
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
835835

836836
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
837-
if key not in passed_class_obj:
837+
if key not in passed_class_obj or key == "scheduler":
838838
continue
839839
class_name = passed_class_obj[key].__class__.__name__
840+
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
840841
if class_name != expected_class_name:
841842
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
842843

0 commit comments

Comments
 (0)