Skip to content

Commit 6d2a80c

Browse files
committed
up
1 parent 219a8ab commit 6d2a80c

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,8 +1834,14 @@ def _is_union(annotation: Any) -> bool:
18341834
return union_type is not None and isinstance(annotation, union_type)
18351835

18361836
def _normalize_annotation(annotation: Any) -> tuple[type, ...]:
1837-
if annotation in (inspect._empty, None) or annotation is Any:
1838-
return ()
1837+
if annotation is inspect._empty:
1838+
return (inspect.Signature.empty,)
1839+
1840+
if annotation is None:
1841+
return (type(None),)
1842+
1843+
if annotation is Any:
1844+
return (Any,)
18391845

18401846
if inspect.isclass(annotation):
18411847
return (annotation,)
@@ -2149,9 +2155,11 @@ def from_pipe(cls, pipeline, **kwargs):
21492155
for name, component in pipeline.components.items():
21502156
if name in expected_modules and name not in passed_class_obj:
21512157
# for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
2158+
expected = component_types.get(name, ())
21522159
if (
21532160
not isinstance(component, ModelMixin)
2154-
or type(component) in component_types[name]
2161+
or not expected
2162+
or _is_valid_type(component, expected)
21552163
or (component is None and name in cls._optional_components)
21562164
):
21572165
original_class_obj[name] = component

0 commit comments

Comments
 (0)