Skip to content

Commit 420c78c

Browse files
committed
update
1 parent ac4c23c commit 420c78c

File tree

2 files changed

+44
-61
lines changed

2 files changed

+44
-61
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,14 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
165165
variant_file_re = re.compile(
166166
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
167167
)
168-
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
169168
# `text_encoder/pytorch_model.bin.index.fp16.json`
170169
variant_index_re = re.compile(
171170
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
172171
)
172+
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
173+
legacy_variant_index_re = re.compile(
174+
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
175+
)
173176

174177
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
175178
non_variant_file_re = re.compile(
@@ -209,11 +212,16 @@ def filter_with_regex(filenames, pattern_re):
209212
component_non_variants = set()
210213
if variant is not None:
211214
component_variants = filter_with_regex(component_filenames, variant_file_re)
212-
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
213215
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
214216

217+
component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
218+
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
219+
220+
if component_variants:
215221
variant_filenames.update(
216-
component_variants if component_variants else component_legacy_variants | component_variant_index_files
222+
component_variants | component_variant_index_files
223+
if component_variants
224+
else component_legacy_variants | component_legacy_variant_index_files
217225
)
218226

219227
else:
@@ -225,7 +233,7 @@ def filter_with_regex(filenames, pattern_re):
225233
usable_filenames.update(variant_filenames)
226234

227235
if len(variant_filenames) == 0 and variant is not None:
228-
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
236+
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
229237
raise ValueError(error_message)
230238

231239
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:

tests/pipelines/test_pipeline_utils.py

Lines changed: 32 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import contextlib
22
import io
33
import re
4-
from shutil import ignore_patterns
54
import unittest
65

76
import torch
@@ -213,7 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
213212

214213
class VariantCompatibleSiblingsTest(unittest.TestCase):
215214
def test_only_non_variants_downloaded(self):
216-
ignore_patterns = ["*.bin", "*.msgpack"]
215+
ignore_patterns = ["*.bin"]
217216
variant = "fp16"
218217
filenames = [
219218
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -230,7 +229,7 @@ def test_only_non_variants_downloaded(self):
230229
assert all(variant not in f for f in model_filenames)
231230

232231
def test_only_variants_downloaded(self):
233-
ignore_patterns = ["*.bin", "*.msgpack"]
232+
ignore_patterns = ["*.bin"]
234233
variant = "fp16"
235234
filenames = [
236235
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -247,7 +246,7 @@ def test_only_variants_downloaded(self):
247246
assert all(variant in f for f in model_filenames)
248247

249248
def test_mixed_variants_downloaded(self):
250-
ignore_patterns = ["*.bin", "*.msgpack"]
249+
ignore_patterns = ["*.bin"]
251250
variant = "fp16"
252251
non_variant_file = "text_encoder/model.safetensors"
253252
filenames = [
@@ -263,7 +262,7 @@ def test_mixed_variants_downloaded(self):
263262
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
264263

265264
def test_non_variants_in_main_dir_downloaded(self):
266-
ignore_patterns = ["*.bin", "*.msgpack"]
265+
ignore_patterns = ["*.bin"]
267266
variant = "fp16"
268267
filenames = [
269268
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -277,6 +276,7 @@ def test_non_variants_in_main_dir_downloaded(self):
277276
assert all(variant not in f for f in model_filenames)
278277

279278
def test_variants_in_main_dir_downloaded(self):
279+
ignore_patterns = ["*.bin"]
280280
variant = "fp16"
281281
filenames = [
282282
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -292,7 +292,7 @@ def test_variants_in_main_dir_downloaded(self):
292292
assert all(variant in f for f in model_filenames)
293293

294294
def test_mixed_variants_in_main_dir_downloaded(self):
295-
use_safetensors = True
295+
ignore_patterns = ["*.bin"]
296296
variant = "fp16"
297297
non_variant_file = "model.safetensors"
298298
filenames = [
@@ -306,7 +306,7 @@ def test_mixed_variants_in_main_dir_downloaded(self):
306306
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
307307

308308
def test_sharded_variants_in_main_dir_downloaded(self):
309-
use_safetensors = True
309+
ignore_patterns = ["*.bin"]
310310
variant = "fp16"
311311
filenames = [
312312
"diffusion_pytorch_model.safetensors.index.json",
@@ -323,7 +323,7 @@ def test_sharded_variants_in_main_dir_downloaded(self):
323323
assert all(variant in f for f in model_filenames)
324324

325325
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
326-
use_safetensors = True
326+
ignore_patterns = ["*.bin"]
327327
variant = "fp16"
328328
filenames = [
329329
"diffusion_pytorch_model.safetensors.index.json",
@@ -338,7 +338,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
338338
assert all(variant in f for f in model_filenames)
339339

340340
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
341-
use_safetensors = True
341+
ignore_patterns = ["*.bin"]
342342
variant = "fp16"
343343
filenames = [
344344
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -355,7 +355,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
355355
assert all(variant not in f for f in model_filenames)
356356

357357
def test_sharded_non_variants_downloaded(self):
358-
use_safetensors = True
358+
ignore_patterns = ["*.bin"]
359359
variant = "fp16"
360360
filenames = [
361361
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -372,7 +372,7 @@ def test_sharded_non_variants_downloaded(self):
372372
assert all(variant not in f for f in model_filenames)
373373

374374
def test_sharded_variants_downloaded(self):
375-
use_safetensors = True
375+
ignore_patterns = ["*.bin"]
376376
variant = "fp16"
377377
filenames = [
378378
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -387,9 +387,10 @@ def test_sharded_variants_downloaded(self):
387387
filenames, variant=variant, ignore_patterns=ignore_patterns
388388
)
389389
assert all(variant in f for f in model_filenames)
390+
assert model_filenames == variant_filenames
390391

391392
def test_single_variant_with_sharded_non_variant_downloaded(self):
392-
use_safetensors = True
393+
ignore_patterns = ["*.bin"]
393394
variant = "fp16"
394395
filenames = [
395396
"unet/diffusion_pytorch_model.safetensors.index.json",
@@ -404,7 +405,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
404405
assert all(variant in f for f in model_filenames)
405406

406407
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
407-
use_safetensors = True
408+
ignore_patterns = ["*.bin"]
408409
variant = "fp16"
409410
allowed_non_variant = "unet"
410411
filenames = [
@@ -424,7 +425,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
424425
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
425426

426427
def test_sharded_mixed_variants_downloaded(self):
427-
use_safetensors = True
428+
ignore_patterns = ["*.bin"]
428429
variant = "fp16"
429430
allowed_non_variant = "unet"
430431
filenames = [
@@ -445,56 +446,30 @@ def test_sharded_mixed_variants_downloaded(self):
445446
)
446447
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
447448

448-
def test_variant_ignored_if_use_safetensors(self):
449-
use_safetensors = True
450-
variant = "fp16"
451-
filenames = [
452-
f"vae/diffusion_pytorch_model.{variant}.bin",
453-
f"text_encoder/model.{variant}.bin",
454-
f"unet/diffusion_pytorch_model.{variant}.bin",
455-
"vae/diffusion_pytorch_model.safetensors",
456-
"text_encoder/model.safetensors",
457-
"unet/diffusion_pytorch_model.safetensors",
458-
]
459-
model_filenames, variant_filenames = variant_compatible_siblings(
460-
filenames, variant=variant, ignore_patterns=ignore_patterns
461-
)
462-
assert all(variant not in f for f in model_filenames)
463-
464449
def test_downloading_when_no_variant_exists(self):
465-
use_safetensors = True
450+
ignore_patterns = ["*.bin"]
466451
variant = "fp16"
467452
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
468-
model_filenames, variant_filenames = variant_compatible_siblings(
469-
filenames, variant=variant, ignore_patterns=ignore_patterns
470-
)
471-
assert len(model_filenames) != 0
472-
473-
def test_downloading_use_safetensors_no_variant_exists(self):
474-
use_safetensors = True
475-
variant = "fp16"
476-
filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"]
477-
model_filenames, variant_filenames = variant_compatible_siblings(
478-
filenames, variant=variant, ignore_patterns=ignore_patterns
479-
)
480-
assert all(variant not in f for f in model_filenames)
453+
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
454+
model_filenames, variant_filenames = variant_compatible_siblings(
455+
filenames, variant=variant, ignore_patterns=ignore_patterns
456+
)
481457

482458
def test_downloading_use_safetensors_false(self):
483-
use_safetensors = False
484-
variant = "fp16"
459+
ignore_patterns = ["*.safetensors"]
485460
filenames = [
486461
"text_encoder/model.bin",
487462
"unet/diffusion_pytorch_model.bin",
488463
"unet/diffusion_pytorch_model.safetensors",
489464
]
490465
model_filenames, variant_filenames = variant_compatible_siblings(
491-
filenames, variant=variant, ignore_patterns=ignore_patterns
466+
filenames, variant=None, ignore_patterns=ignore_patterns
492467
)
493468

494469
assert all(".safetensors" not in f for f in model_filenames)
495470

496471
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
497-
use_safetensors = True
472+
ignore_patterns = ["*.bin"]
498473
variant = "fp16"
499474
allowed_non_variant = "diffusion_pytorch_model.safetensors"
500475
filenames = [
@@ -506,8 +481,8 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
506481
)
507482
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
508483

509-
def test_download_variants_when_component_has_no_variant(self):
510-
use_safetensors = True
484+
def test_download_variants_when_component_has_no_safetensors_variant(self):
485+
ignore_patterns = None
511486
variant = "fp16"
512487
filenames = [
513488
f"unet/diffusion_pytorch_model.{variant}.bin",
@@ -522,8 +497,8 @@ def test_download_variants_when_component_has_no_variant(self):
522497
f"vae/diffusion_pytorch_model.{variant}.safetensors",
523498
} == model_filenames
524499

525-
def test_download_sharded_variants_when_component_has_no_safetensors_variant(self):
526-
use_safetensors = True
500+
def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
501+
ignore_patterns = ["*.bin"]
527502
variant = "fp16"
528503
filenames = [
529504
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
@@ -538,13 +513,13 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel
538513
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
539514
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
540515
]
541-
model_filenames, variant_filenames = variant_compatible_siblings(
542-
filenames, variant=variant, ignore_patterns=ignore_patterns
543-
)
544-
assert all(variant not in f for f in model_filenames)
516+
with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
517+
model_filenames, variant_filenames = variant_compatible_siblings(
518+
filenames, variant=variant, ignore_patterns=ignore_patterns
519+
)
545520

546521
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
547-
use_safetensors = False
522+
ignore_patterns = ["*.safetensors"]
548523
allowed_non_variant = "unet"
549524
variant = "fp16"
550525
filenames = [

0 commit comments

Comments
 (0)