Skip to content

Commit 5af8c7f

Browse files
committed
scheduler in expected types
1 parent 44f24a4 commit 5af8c7f

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -834,20 +834,22 @@ def load_module(name, value):
834834
return True
835835

836836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
837-
scheduler_types = expected_types["scheduler"][0]
838-
if isinstance(scheduler_types, enum.EnumMeta):
839-
scheduler_types = list(scheduler_types)
840-
else:
841-
scheduler_types = [str(scheduler_types)]
842-
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]
837+
scheduler_types = None
838+
if "scheduler" in expected_types:
839+
scheduler_types = expected_types["scheduler"][0]
840+
if isinstance(scheduler_types, enum.EnumMeta):
841+
scheduler_types = list(scheduler_types)
842+
else:
843+
scheduler_types = [str(scheduler_types)]
844+
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]
843845

844846
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
845847
if key not in passed_class_obj:
846848
continue
847849
class_name = passed_class_obj[key].__class__.__name__
848850
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
849851
expected_class_name = expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name
850-
if key == "scheduler" and class_name not in scheduler_types:
852+
if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types:
851853
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
852854
elif key != "scheduler" and class_name != expected_class_name:
853855
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")

0 commit comments

Comments
 (0)