Skip to content

Commit b0d17ad

Browse files
committed
init test
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 1a10fa0 commit b0d17ad

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
9393

9494

95-
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
95+
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
9696
"""
9797
Checking for safetensors compatibility:
9898
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,28 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
103103
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
104104
extension is replaced with ".safetensors"
105105
"""
106+
weight_names = [
107+
WEIGHTS_NAME,
108+
SAFETENSORS_WEIGHTS_NAME,
109+
FLAX_WEIGHTS_NAME,
110+
ONNX_WEIGHTS_NAME,
111+
ONNX_EXTERNAL_WEIGHTS_NAME,
112+
]
113+
114+
if is_transformers_available():
115+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
116+
117+
# model_pytorch, diffusion_model_pytorch, ...
118+
weight_prefixes = [w.split(".")[0] for w in weight_names]
119+
# .bin, .safetensors, ...
120+
weight_suffixs = [w.split(".")[-1] for w in weight_names]
121+
# -00001-of-00002
122+
transformers_index_format = r"\d{5}-of-\d{5}"
123+
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
124+
non_variant_file_re = re.compile(
125+
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
126+
)
127+
106128
passed_components = passed_components or []
107129
if folder_names:
108130
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
@@ -130,6 +152,8 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
130152
for component, component_filenames in components.items():
131153
matches = []
132154
for component_filename in component_filenames:
155+
if variant is None:
156+
component_filename = filter_with_regex(component_filename, non_variant_file_re)
133157
filename, extension = os.path.splitext(component_filename)
134158

135159
match_exists = extension == ".safetensors"
@@ -158,6 +182,8 @@ def filter_model_files(filenames):
158182

159183
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
160184

185+
def filter_with_regex(filenames, pattern_re):
186+
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
161187

162188
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
163189
weight_names = [
@@ -207,9 +233,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
207233
# interested in the extension name
208234
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
209235

210-
def filter_with_regex(filenames, pattern_re):
211-
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
212-
213236
# Group files by component
214237
components = {}
215238
for filename in filenames:
@@ -997,7 +1020,7 @@ def _get_ignore_patterns(
9971020
use_safetensors
9981021
and not allow_pickle
9991022
and not is_safetensors_compatible(
1000-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1023+
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
10011024
)
10021025
):
10031026
raise EnvironmentError(
@@ -1008,7 +1031,7 @@ def _get_ignore_patterns(
10081031
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
10091032

10101033
elif use_safetensors and is_safetensors_compatible(
1011-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1034+
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
10121035
):
10131036
ignore_patterns = ["*.bin", "*.msgpack"]
10141037

0 commit comments

Comments
 (0)