Skip to content

Commit ac4c23c

Browse files
committed
update
1 parent c40f60c commit ac4c23c

File tree

2 files changed

+51
-58
lines changed

2 files changed

+51
-58
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
165165
variant_file_re = re.compile(
166166
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
167167
)
168+
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
168169
# `text_encoder/pytorch_model.bin.index.fp16.json`
169170
variant_index_re = re.compile(
170171
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
@@ -177,28 +178,16 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
177178
# `text_encoder/pytorch_model.bin.index.json`
178179
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
179180

180-
def filter_for_compatible_extensions(filenames, variant=None, ignore_patterns=None):
181-
def extension_filter(f):
182-
return not any(f.endswith(pattern) for pattern in ignore_patterns)
181+
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
182+
if not ignore_patterns:
183+
return filenames
183184

184-
tensor_files = {f for f in filenames if extension_filter(f)}
185-
non_variant_indexes = {
186-
f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
187-
}
188-
variant_indexes = {
189-
f
190-
for f in filenames
191-
if variant is not None and variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
192-
}
193-
194-
return tensor_files | non_variant_indexes | variant_indexes
185+
# ignore patterns uses glob style patterns e.g *.safetensors but we're only
186+
# interested in the extension name
187+
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
195188

196-
def filter_for_weights_and_indexes(filenames, file_re, index_re):
197-
weights = {f for f in filenames if file_re.match(f.split("/")[-1]) is not None}
198-
indexes = {f for f in filenames if index_re.match(f.split("/")[-1]) is not None}
199-
filtered_filenames = weights | indexes
200-
201-
return filtered_filenames
189+
def filter_with_regex(filenames, pattern_re):
190+
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
202191

203192
# Group files by component
204193
components = {}
@@ -213,23 +202,27 @@ def filter_for_weights_and_indexes(filenames, file_re, index_re):
213202
usable_filenames = set()
214203
variant_filenames = set()
215204
for component, component_filenames in components.items():
216-
component_filenames = filter_for_compatible_extensions(
217-
component_filenames, variant=variant, ignore_patterns=ignore_patterns
218-
)
205+
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
219206

220207
component_variants = set()
208+
component_legacy_variants = set()
209+
component_non_variants = set()
221210
if variant is not None:
222-
component_variants = filter_for_weights_and_indexes(component_filenames, variant_file_re, variant_index_re)
211+
component_variants = filter_with_regex(component_filenames, variant_file_re)
212+
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
213+
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
223214

224-
if component_variants:
225-
variant_filenames.update(component_variants)
226-
usable_filenames.update(component_variants)
215+
variant_filenames.update(
216+
component_variants if component_variants else component_legacy_variants | component_variant_index_files
217+
)
227218

228219
else:
229-
component_non_variants = filter_for_weights_and_indexes(
230-
component_filenames, non_variant_file_re, non_variant_index_re
231-
)
232-
usable_filenames.update(component_non_variants)
220+
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
221+
component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
222+
223+
usable_filenames.update(component_non_variants | component_variant_index_files)
224+
225+
usable_filenames.update(variant_filenames)
233226

234227
if len(variant_filenames) == 0 and variant is not None:
235228
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."

tests/pipelines/test_pipeline_utils.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import contextlib
22
import io
33
import re
4+
from shutil import ignore_patterns
45
import unittest
56

67
import torch
@@ -212,7 +213,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
212213

213214
class VariantCompatibleSiblingsTest(unittest.TestCase):
214215
def test_only_non_variants_downloaded(self):
215-
use_safetensors = True
216+
ignore_patterns = ["*.bin", "*.msgpack"]
216217
variant = "fp16"
217218
filenames = [
218219
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -224,12 +225,12 @@ def test_only_non_variants_downloaded(self):
224225
]
225226

226227
model_filenames, variant_filenames = variant_compatible_siblings(
227-
filenames, variant=None, use_safetensors=use_safetensors
228+
filenames, variant=None, ignore_patterns=ignore_patterns
228229
)
229230
assert all(variant not in f for f in model_filenames)
230231

231232
def test_only_variants_downloaded(self):
232-
use_safetensors = True
233+
ignore_patterns = ["*.bin", "*.msgpack"]
233234
variant = "fp16"
234235
filenames = [
235236
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -241,12 +242,12 @@ def test_only_variants_downloaded(self):
241242
]
242243

243244
model_filenames, variant_filenames = variant_compatible_siblings(
244-
filenames, variant=variant, use_safetensors=use_safetensors
245+
filenames, variant=variant, ignore_patterns=ignore_patterns
245246
)
246247
assert all(variant in f for f in model_filenames)
247248

248249
def test_mixed_variants_downloaded(self):
249-
use_safetensors = True
250+
ignore_patterns = ["*.bin", "*.msgpack"]
250251
variant = "fp16"
251252
non_variant_file = "text_encoder/model.safetensors"
252253
filenames = [
@@ -257,12 +258,12 @@ def test_mixed_variants_downloaded(self):
257258
"unet/diffusion_pytorch_model.safetensors",
258259
]
259260
model_filenames, variant_filenames = variant_compatible_siblings(
260-
filenames, variant=variant, use_safetensors=use_safetensors
261+
filenames, variant=variant, ignore_patterns=ignore_patterns
261262
)
262263
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
263264

264265
def test_non_variants_in_main_dir_downloaded(self):
265-
use_safetensors = True
266+
ignore_patterns = ["*.bin", "*.msgpack"]
266267
variant = "fp16"
267268
filenames = [
268269
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -271,12 +272,11 @@ def test_non_variants_in_main_dir_downloaded(self):
271272
f"model.{variant}.safetensors",
272273
]
273274
model_filenames, variant_filenames = variant_compatible_siblings(
274-
filenames, variant=None, use_safetensors=use_safetensors
275+
filenames, variant=None, ignore_patterns=ignore_patterns
275276
)
276277
assert all(variant not in f for f in model_filenames)
277278

278279
def test_variants_in_main_dir_downloaded(self):
279-
use_safetensors = True
280280
variant = "fp16"
281281
filenames = [
282282
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -287,7 +287,7 @@ def test_variants_in_main_dir_downloaded(self):
287287
"diffusion_pytorch_model.safetensors",
288288
]
289289
model_filenames, variant_filenames = variant_compatible_siblings(
290-
filenames, variant=variant, use_safetensors=use_safetensors
290+
filenames, variant=variant, ignore_patterns=ignore_patterns
291291
)
292292
assert all(variant in f for f in model_filenames)
293293

@@ -301,7 +301,7 @@ def test_mixed_variants_in_main_dir_downloaded(self):
301301
"model.safetensors",
302302
]
303303
model_filenames, variant_filenames = variant_compatible_siblings(
304-
filenames, variant=variant, use_safetensors=use_safetensors
304+
filenames, variant=variant, ignore_patterns=ignore_patterns
305305
)
306306
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
307307

@@ -318,7 +318,7 @@ def test_sharded_variants_in_main_dir_downloaded(self):
318318
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
319319
]
320320
model_filenames, variant_filenames = variant_compatible_siblings(
321-
filenames, variant=variant, use_safetensors=use_safetensors
321+
filenames, variant=variant, ignore_patterns=ignore_patterns
322322
)
323323
assert all(variant in f for f in model_filenames)
324324

@@ -333,7 +333,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
333333
f"diffusion_pytorch_model.{variant}.safetensors",
334334
]
335335
model_filenames, variant_filenames = variant_compatible_siblings(
336-
filenames, variant=variant, use_safetensors=use_safetensors
336+
filenames, variant=variant, ignore_patterns=ignore_patterns
337337
)
338338
assert all(variant in f for f in model_filenames)
339339

@@ -350,7 +350,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
350350
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
351351
]
352352
model_filenames, variant_filenames = variant_compatible_siblings(
353-
filenames, variant=None, use_safetensors=use_safetensors
353+
filenames, variant=None, ignore_patterns=ignore_patterns
354354
)
355355
assert all(variant not in f for f in model_filenames)
356356

@@ -367,7 +367,7 @@ def test_sharded_non_variants_downloaded(self):
367367
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
368368
]
369369
model_filenames, variant_filenames = variant_compatible_siblings(
370-
filenames, variant=None, use_safetensors=use_safetensors
370+
filenames, variant=None, ignore_patterns=ignore_patterns
371371
)
372372
assert all(variant not in f for f in model_filenames)
373373

@@ -384,7 +384,7 @@ def test_sharded_variants_downloaded(self):
384384
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
385385
]
386386
model_filenames, variant_filenames = variant_compatible_siblings(
387-
filenames, variant=variant, use_safetensors=use_safetensors
387+
filenames, variant=variant, ignore_patterns=ignore_patterns
388388
)
389389
assert all(variant in f for f in model_filenames)
390390

@@ -399,7 +399,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
399399
f"unet/diffusion_pytorch_model.{variant}.safetensors",
400400
]
401401
model_filenames, variant_filenames = variant_compatible_siblings(
402-
filenames, variant=variant, use_safetensors=use_safetensors
402+
filenames, variant=variant, ignore_patterns=ignore_patterns
403403
)
404404
assert all(variant in f for f in model_filenames)
405405

@@ -419,7 +419,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
419419
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
420420
]
421421
model_filenames, variant_filenames = variant_compatible_siblings(
422-
filenames, variant=variant, use_safetensors=use_safetensors
422+
filenames, variant=variant, ignore_patterns=ignore_patterns
423423
)
424424
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
425425

@@ -441,7 +441,7 @@ def test_sharded_mixed_variants_downloaded(self):
441441
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
442442
]
443443
model_filenames, variant_filenames = variant_compatible_siblings(
444-
filenames, variant=variant, use_safetensors=use_safetensors
444+
filenames, variant=variant, ignore_patterns=ignore_patterns
445445
)
446446
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
447447

@@ -457,7 +457,7 @@ def test_variant_ignored_if_use_safetensors(self):
457457
"unet/diffusion_pytorch_model.safetensors",
458458
]
459459
model_filenames, variant_filenames = variant_compatible_siblings(
460-
filenames, variant=variant, use_safetensors=use_safetensors
460+
filenames, variant=variant, ignore_patterns=ignore_patterns
461461
)
462462
assert all(variant not in f for f in model_filenames)
463463

@@ -466,7 +466,7 @@ def test_downloading_when_no_variant_exists(self):
466466
variant = "fp16"
467467
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
468468
model_filenames, variant_filenames = variant_compatible_siblings(
469-
filenames, variant=variant, use_safetensors=use_safetensors
469+
filenames, variant=variant, ignore_patterns=ignore_patterns
470470
)
471471
assert len(model_filenames) != 0
472472

@@ -475,7 +475,7 @@ def test_downloading_use_safetensors_no_variant_exists(self):
475475
variant = "fp16"
476476
filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"]
477477
model_filenames, variant_filenames = variant_compatible_siblings(
478-
filenames, variant=variant, use_safetensors=use_safetensors
478+
filenames, variant=variant, ignore_patterns=ignore_patterns
479479
)
480480
assert all(variant not in f for f in model_filenames)
481481

@@ -488,7 +488,7 @@ def test_downloading_use_safetensors_false(self):
488488
"unet/diffusion_pytorch_model.safetensors",
489489
]
490490
model_filenames, variant_filenames = variant_compatible_siblings(
491-
filenames, variant=variant, use_safetensors=use_safetensors
491+
filenames, variant=variant, ignore_patterns=ignore_patterns
492492
)
493493

494494
assert all(".safetensors" not in f for f in model_filenames)
@@ -502,7 +502,7 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
502502
"diffusion_pytorch_model.safetensors",
503503
]
504504
model_filenames, variant_filenames = variant_compatible_siblings(
505-
filenames, variant=variant, use_safetensors=use_safetensors
505+
filenames, variant=variant, ignore_patterns=ignore_patterns
506506
)
507507
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
508508

@@ -515,7 +515,7 @@ def test_download_variants_when_component_has_no_variant(self):
515515
f"vae/diffusion_pytorch_model.{variant}.safetensors",
516516
]
517517
model_filenames, variant_filenames = variant_compatible_siblings(
518-
filenames, variant=variant, use_safetensors=use_safetensors
518+
filenames, variant=variant, ignore_patterns=ignore_patterns
519519
)
520520
assert {
521521
f"unet/diffusion_pytorch_model.{variant}.bin",
@@ -539,7 +539,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel
539539
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
540540
]
541541
model_filenames, variant_filenames = variant_compatible_siblings(
542-
filenames, variant=variant, use_safetensors=use_safetensors
542+
filenames, variant=variant, ignore_patterns=ignore_patterns
543543
)
544544
assert all(variant not in f for f in model_filenames)
545545

@@ -561,7 +561,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and
561561
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
562562
]
563563
model_filenames, variant_filenames = variant_compatible_siblings(
564-
filenames, variant=variant, use_safetensors=use_safetensors
564+
filenames, variant=variant, ignore_patterns=ignore_patterns
565565
)
566566
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
567567

0 commit comments

Comments
 (0)