Skip to content

Commit dba12b6

Browse files
committed
support union
1 parent c5e1e2d commit dba12b6

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import sys
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
25+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, _UnionGenericAlias
2626

2727
import numpy as np
2828
import PIL.Image
@@ -836,11 +836,12 @@ def load_module(name, value):
836836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
837837
scheduler_types = None
838838
if "scheduler" in expected_types:
839-
scheduler_types = expected_types["scheduler"][0]
840-
if isinstance(scheduler_types, enum.EnumMeta):
841-
scheduler_types = list(scheduler_types)
842-
else:
843-
scheduler_types = [str(scheduler_types)]
839+
scheduler_types = []
840+
for scheduler_type in expected_types["scheduler"]:
841+
if isinstance(scheduler_type, enum.EnumMeta):
842+
scheduler_types.extend(list(scheduler_type))
843+
else:
844+
scheduler_types.extend([str(scheduler_type)])
844845
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]
845846

846847
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):

0 commit comments

Comments
 (0)