Skip to content

Commit 2fd4508

Browse files
committed
Check correct model type is passed to from_pretrained
1 parent 43534a8 commit 2fd4508

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,13 @@ def load_module(name, value):
833833

834834
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
835835

836+
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
837+
if key not in passed_class_obj:
838+
continue
839+
class_name = passed_class_obj[key].__class__.__name__
840+
if class_name != expected_class_name:
841+
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
842+
836843
# Special case: safety_checker must be loaded separately when using `from_flax`
837844
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
838845
raise NotImplementedError(

0 commit comments

Comments
 (0)