Skip to content

Commit 905e8d7

Browse files
committed
Created typing_utils.py
1 parent 7c2e5d8 commit 905e8d7

File tree

6 files changed

+101
-81
lines changed

6 files changed

+101
-81
lines changed

docs/source/en/_toctree.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@
543543
title: Overview
544544
- local: api/schedulers/cm_stochastic_iterative
545545
title: CMStochasticIterativeScheduler
546+
- local: api/schedulers/ddim_cogvideox
547+
title: CogVideoXDDIMScheduler
548+
- local: api/schedulers/multistep_dpm_solver_cogvideox
549+
title: CogVideoXDPMScheduler
546550
- local: api/schedulers/consistency_decoder
547551
title: ConsistencyDecoderScheduler
548552
- local: api/schedulers/cosine_dpm
@@ -551,8 +555,6 @@
551555
title: DDIMInverseScheduler
552556
- local: api/schedulers/ddim
553557
title: DDIMScheduler
554-
- local: api/schedulers/ddim_cogvideox
555-
title: CogVideoXDDIMScheduler
556558
- local: api/schedulers/ddpm
557559
title: DDPMScheduler
558560
- local: api/schedulers/deis
@@ -565,8 +567,6 @@
565567
title: DPMSolverSDEScheduler
566568
- local: api/schedulers/singlestep_dpm_solver
567569
title: DPMSolverSinglestepScheduler
568-
- local: api/schedulers/multistep_dpm_solver_cogvideox
569-
title: CogVideoXDPMScheduler
570570
- local: api/schedulers/edm_multistep_dpm_solver
571571
title: EDMDPMSolverMultistepScheduler
572572
- local: api/schedulers/edm_euler

src/diffusers/loaders/ip_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
from safetensors import safe_open
2222

2323
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
24-
from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type
2524
from ..utils import (
2625
USE_PEFT_BACKEND,
26+
_get_detailed_type,
2727
_get_model_file,
28+
_is_valid_type,
2829
is_accelerate_available,
2930
is_torch_version,
3031
is_transformers_available,

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
import warnings
1919
from pathlib import Path
20-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
20+
from typing import Any, Callable, Dict, List, Optional, Union
2121

2222
import requests
2323
import torch
@@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
10591059
break
10601060
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
10611061
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
1062-
1063-
1064-
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
1065-
"""
1066-
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
1067-
the correct type as well.
1068-
"""
1069-
if not isinstance(class_or_tuple, tuple):
1070-
class_or_tuple = (class_or_tuple,)
1071-
1072-
# Unpack unions
1073-
unpacked_class_or_tuple = []
1074-
for t in class_or_tuple:
1075-
if get_origin(t) is Union:
1076-
unpacked_class_or_tuple.extend(get_args(t))
1077-
else:
1078-
unpacked_class_or_tuple.append(t)
1079-
class_or_tuple = tuple(unpacked_class_or_tuple)
1080-
1081-
if Any in class_or_tuple:
1082-
return True
1083-
1084-
obj_type = type(obj)
1085-
# Classes with obj's type
1086-
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
1087-
1088-
# Singular types (e.g. int, ControlNet, ...)
1089-
# Untyped collections (e.g. List, but not List[int])
1090-
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
1091-
if () in elem_class_or_tuple:
1092-
return True
1093-
# Typed lists or sets
1094-
elif obj_type in (list, set):
1095-
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
1096-
# Typed tuples
1097-
elif obj_type is tuple:
1098-
return any(
1099-
# Tuples with any length and single type (e.g. Tuple[int, ...])
1100-
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
1101-
or
1102-
# Tuples with fixed length and any types (e.g. Tuple[int, str])
1103-
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
1104-
for t in elem_class_or_tuple
1105-
)
1106-
# Typed dicts
1107-
elif obj_type is dict:
1108-
return any(
1109-
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
1110-
for kt, vt in elem_class_or_tuple
1111-
)
1112-
1113-
else:
1114-
return False
1115-
1116-
1117-
def _get_detailed_type(obj: Any) -> Type:
1118-
"""
1119-
Gets a detailed type for an object, including nested types for collections.
1120-
"""
1121-
obj_type = type(obj)
1122-
1123-
if obj_type in (list, set):
1124-
obj_origin_type = List if obj_type is list else Set
1125-
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
1126-
return obj_origin_type[elems_type]
1127-
elif obj_type is tuple:
1128-
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
1129-
elif obj_type is dict:
1130-
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
1131-
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
1132-
return Dict[keys_type, values_type]
1133-
else:
1134-
return obj_type

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
DEPRECATED_REVISION_ARGS,
5555
BaseOutput,
5656
PushToHubMixin,
57+
_get_detailed_type,
58+
_is_valid_type,
5759
is_accelerate_available,
5860
is_accelerate_version,
5961
is_torch_npu_available,
@@ -78,12 +80,10 @@
7880
_fetch_class_library_tuple,
7981
_get_custom_components_and_folders,
8082
_get_custom_pipeline_class,
81-
_get_detailed_type,
8283
_get_final_device_map,
8384
_get_ignore_patterns,
8485
_get_pipeline_class,
8586
_identify_model_variants,
86-
_is_valid_type,
8787
_maybe_raise_error_for_incorrect_transformers,
8888
_maybe_raise_warning_for_inpainting,
8989
_resolve_custom_pipeline_and_cls,

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
convert_state_dict_to_peft,
124124
convert_unet_state_dict_to_peft,
125125
)
126+
from .typing_utils import _get_detailed_type, _is_valid_type
126127

127128

128129
logger = get_logger(__name__)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Typing utilities: Utilities related to type checking and validation
16+
"""
17+
18+
from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin
19+
20+
21+
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
22+
"""
23+
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
24+
the correct type as well.
25+
"""
26+
if not isinstance(class_or_tuple, tuple):
27+
class_or_tuple = (class_or_tuple,)
28+
29+
# Unpack unions
30+
unpacked_class_or_tuple = []
31+
for t in class_or_tuple:
32+
if get_origin(t) is Union:
33+
unpacked_class_or_tuple.extend(get_args(t))
34+
else:
35+
unpacked_class_or_tuple.append(t)
36+
class_or_tuple = tuple(unpacked_class_or_tuple)
37+
38+
if Any in class_or_tuple:
39+
return True
40+
41+
obj_type = type(obj)
42+
# Classes with obj's type
43+
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
44+
45+
# Singular types (e.g. int, ControlNet, ...)
46+
# Untyped collections (e.g. List, but not List[int])
47+
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
48+
if () in elem_class_or_tuple:
49+
return True
50+
# Typed lists or sets
51+
elif obj_type in (list, set):
52+
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
53+
# Typed tuples
54+
elif obj_type is tuple:
55+
return any(
56+
# Tuples with any length and single type (e.g. Tuple[int, ...])
57+
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
58+
or
59+
# Tuples with fixed length and any types (e.g. Tuple[int, str])
60+
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
61+
for t in elem_class_or_tuple
62+
)
63+
# Typed dicts
64+
elif obj_type is dict:
65+
return any(
66+
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
67+
for kt, vt in elem_class_or_tuple
68+
)
69+
70+
else:
71+
return False
72+
73+
74+
def _get_detailed_type(obj: Any) -> Type:
75+
"""
76+
Gets a detailed type for an object, including nested types for collections.
77+
"""
78+
obj_type = type(obj)
79+
80+
if obj_type in (list, set):
81+
obj_origin_type = List if obj_type is list else Set
82+
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
83+
return obj_origin_type[elems_type]
84+
elif obj_type is tuple:
85+
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
86+
elif obj_type is dict:
87+
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
88+
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
89+
return Dict[keys_type, values_type]
90+
else:
91+
return obj_type

0 commit comments

Comments
 (0)