11import contextlib
22import io
33import re
4+ from shutil import ignore_patterns
45import unittest
56
67import torch
@@ -212,7 +213,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
212213
213214class VariantCompatibleSiblingsTest (unittest .TestCase ):
214215 def test_only_non_variants_downloaded (self ):
215- use_safetensors = True
216+ ignore_patterns = [ "*.bin" , "*.msgpack" ]
216217 variant = "fp16"
217218 filenames = [
218219 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -224,12 +225,12 @@ def test_only_non_variants_downloaded(self):
224225 ]
225226
226227 model_filenames , variant_filenames = variant_compatible_siblings (
227- filenames , variant = None , use_safetensors = use_safetensors
228+ filenames , variant = None , ignore_patterns = ignore_patterns
228229 )
229230 assert all (variant not in f for f in model_filenames )
230231
231232 def test_only_variants_downloaded (self ):
232- use_safetensors = True
233+ ignore_patterns = [ "*.bin" , "*.msgpack" ]
233234 variant = "fp16"
234235 filenames = [
235236 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -241,12 +242,12 @@ def test_only_variants_downloaded(self):
241242 ]
242243
243244 model_filenames , variant_filenames = variant_compatible_siblings (
244- filenames , variant = variant , use_safetensors = use_safetensors
245+ filenames , variant = variant , ignore_patterns = ignore_patterns
245246 )
246247 assert all (variant in f for f in model_filenames )
247248
248249 def test_mixed_variants_downloaded (self ):
249- use_safetensors = True
250+ ignore_patterns = [ "*.bin" , "*.msgpack" ]
250251 variant = "fp16"
251252 non_variant_file = "text_encoder/model.safetensors"
252253 filenames = [
@@ -257,12 +258,12 @@ def test_mixed_variants_downloaded(self):
257258 "unet/diffusion_pytorch_model.safetensors" ,
258259 ]
259260 model_filenames , variant_filenames = variant_compatible_siblings (
260- filenames , variant = variant , use_safetensors = use_safetensors
261+ filenames , variant = variant , ignore_patterns = ignore_patterns
261262 )
262263 assert all (variant in f if f != non_variant_file else variant not in f for f in model_filenames )
263264
264265 def test_non_variants_in_main_dir_downloaded (self ):
265- use_safetensors = True
266+ ignore_patterns = [ "*.bin" , "*.msgpack" ]
266267 variant = "fp16"
267268 filenames = [
268269 f"diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -271,12 +272,11 @@ def test_non_variants_in_main_dir_downloaded(self):
271272 f"model.{ variant } .safetensors" ,
272273 ]
273274 model_filenames , variant_filenames = variant_compatible_siblings (
274- filenames , variant = None , use_safetensors = use_safetensors
275+ filenames , variant = None , ignore_patterns = ignore_patterns
275276 )
276277 assert all (variant not in f for f in model_filenames )
277278
278279 def test_variants_in_main_dir_downloaded (self ):
279- use_safetensors = True
280280 variant = "fp16"
281281 filenames = [
282282 f"diffusion_pytorch_model.{ variant } .safetensors" ,
@@ -287,7 +287,7 @@ def test_variants_in_main_dir_downloaded(self):
287287 "diffusion_pytorch_model.safetensors" ,
288288 ]
289289 model_filenames , variant_filenames = variant_compatible_siblings (
290- filenames , variant = variant , use_safetensors = use_safetensors
290+ filenames , variant = variant , ignore_patterns = ignore_patterns
291291 )
292292 assert all (variant in f for f in model_filenames )
293293
@@ -301,7 +301,7 @@ def test_mixed_variants_in_main_dir_downloaded(self):
301301 "model.safetensors" ,
302302 ]
303303 model_filenames , variant_filenames = variant_compatible_siblings (
304- filenames , variant = variant , use_safetensors = use_safetensors
304+ filenames , variant = variant , ignore_patterns = ignore_patterns
305305 )
306306 assert all (variant in f if f != non_variant_file else variant not in f for f in model_filenames )
307307
@@ -318,7 +318,7 @@ def test_sharded_variants_in_main_dir_downloaded(self):
318318 f"diffusion_pytorch_model.safetensors.index.{ variant } .json" ,
319319 ]
320320 model_filenames , variant_filenames = variant_compatible_siblings (
321- filenames , variant = variant , use_safetensors = use_safetensors
321+ filenames , variant = variant , ignore_patterns = ignore_patterns
322322 )
323323 assert all (variant in f for f in model_filenames )
324324
@@ -333,7 +333,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
333333 f"diffusion_pytorch_model.{ variant } .safetensors" ,
334334 ]
335335 model_filenames , variant_filenames = variant_compatible_siblings (
336- filenames , variant = variant , use_safetensors = use_safetensors
336+ filenames , variant = variant , ignore_patterns = ignore_patterns
337337 )
338338 assert all (variant in f for f in model_filenames )
339339
@@ -350,7 +350,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
350350 f"diffusion_pytorch_model.{ variant } -00002-of-00002.safetensors" ,
351351 ]
352352 model_filenames , variant_filenames = variant_compatible_siblings (
353- filenames , variant = None , use_safetensors = use_safetensors
353+ filenames , variant = None , ignore_patterns = ignore_patterns
354354 )
355355 assert all (variant not in f for f in model_filenames )
356356
@@ -367,7 +367,7 @@ def test_sharded_non_variants_downloaded(self):
367367 f"unet/diffusion_pytorch_model.{ variant } -00002-of-00002.safetensors" ,
368368 ]
369369 model_filenames , variant_filenames = variant_compatible_siblings (
370- filenames , variant = None , use_safetensors = use_safetensors
370+ filenames , variant = None , ignore_patterns = ignore_patterns
371371 )
372372 assert all (variant not in f for f in model_filenames )
373373
@@ -384,7 +384,7 @@ def test_sharded_variants_downloaded(self):
384384 f"unet/diffusion_pytorch_model.{ variant } -00002-of-00002.safetensors" ,
385385 ]
386386 model_filenames , variant_filenames = variant_compatible_siblings (
387- filenames , variant = variant , use_safetensors = use_safetensors
387+ filenames , variant = variant , ignore_patterns = ignore_patterns
388388 )
389389 assert all (variant in f for f in model_filenames )
390390
@@ -399,7 +399,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
399399 f"unet/diffusion_pytorch_model.{ variant } .safetensors" ,
400400 ]
401401 model_filenames , variant_filenames = variant_compatible_siblings (
402- filenames , variant = variant , use_safetensors = use_safetensors
402+ filenames , variant = variant , ignore_patterns = ignore_patterns
403403 )
404404 assert all (variant in f for f in model_filenames )
405405
@@ -419,7 +419,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
419419 "unet/diffusion_pytorch_model-00003-of-00003.safetensors" ,
420420 ]
421421 model_filenames , variant_filenames = variant_compatible_siblings (
422- filenames , variant = variant , use_safetensors = use_safetensors
422+ filenames , variant = variant , ignore_patterns = ignore_patterns
423423 )
424424 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
425425
@@ -441,7 +441,7 @@ def test_sharded_mixed_variants_downloaded(self):
441441 "vae/diffusion_pytorch_model-00003-of-00003.safetensors" ,
442442 ]
443443 model_filenames , variant_filenames = variant_compatible_siblings (
444- filenames , variant = variant , use_safetensors = use_safetensors
444+ filenames , variant = variant , ignore_patterns = ignore_patterns
445445 )
446446 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
447447
@@ -457,7 +457,7 @@ def test_variant_ignored_if_use_safetensors(self):
457457 "unet/diffusion_pytorch_model.safetensors" ,
458458 ]
459459 model_filenames , variant_filenames = variant_compatible_siblings (
460- filenames , variant = variant , use_safetensors = use_safetensors
460+ filenames , variant = variant , ignore_patterns = ignore_patterns
461461 )
462462 assert all (variant not in f for f in model_filenames )
463463
@@ -466,7 +466,7 @@ def test_downloading_when_no_variant_exists(self):
466466 variant = "fp16"
467467 filenames = ["model.safetensors" , "diffusion_pytorch_model.safetensors" ]
468468 model_filenames , variant_filenames = variant_compatible_siblings (
469- filenames , variant = variant , use_safetensors = use_safetensors
469+ filenames , variant = variant , ignore_patterns = ignore_patterns
470470 )
471471 assert len (model_filenames ) != 0
472472
@@ -475,7 +475,7 @@ def test_downloading_use_safetensors_no_variant_exists(self):
475475 variant = "fp16"
476476 filenames = ["text_encoder/model.bin" , "unet/diffusion_pytorch_model.bin" ]
477477 model_filenames , variant_filenames = variant_compatible_siblings (
478- filenames , variant = variant , use_safetensors = use_safetensors
478+ filenames , variant = variant , ignore_patterns = ignore_patterns
479479 )
480480 assert all (variant not in f for f in model_filenames )
481481
@@ -488,7 +488,7 @@ def test_downloading_use_safetensors_false(self):
488488 "unet/diffusion_pytorch_model.safetensors" ,
489489 ]
490490 model_filenames , variant_filenames = variant_compatible_siblings (
491- filenames , variant = variant , use_safetensors = use_safetensors
491+ filenames , variant = variant , ignore_patterns = ignore_patterns
492492 )
493493
494494 assert all (".safetensors" not in f for f in model_filenames )
@@ -502,7 +502,7 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
502502 "diffusion_pytorch_model.safetensors" ,
503503 ]
504504 model_filenames , variant_filenames = variant_compatible_siblings (
505- filenames , variant = variant , use_safetensors = use_safetensors
505+ filenames , variant = variant , ignore_patterns = ignore_patterns
506506 )
507507 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
508508
@@ -515,7 +515,7 @@ def test_download_variants_when_component_has_no_variant(self):
515515 f"vae/diffusion_pytorch_model.{ variant } .safetensors" ,
516516 ]
517517 model_filenames , variant_filenames = variant_compatible_siblings (
518- filenames , variant = variant , use_safetensors = use_safetensors
518+ filenames , variant = variant , ignore_patterns = ignore_patterns
519519 )
520520 assert {
521521 f"unet/diffusion_pytorch_model.{ variant } .bin" ,
@@ -539,7 +539,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel
539539 f"vae/diffusion_pytorch_model.{ variant } -00001-of-00002.bin" ,
540540 ]
541541 model_filenames , variant_filenames = variant_compatible_siblings (
542- filenames , variant = variant , use_safetensors = use_safetensors
542+ filenames , variant = variant , ignore_patterns = ignore_patterns
543543 )
544544 assert all (variant not in f for f in model_filenames )
545545
@@ -561,7 +561,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and
561561 f"vae/diffusion_pytorch_model.{ variant } -00001-of-00002.bin" ,
562562 ]
563563 model_filenames , variant_filenames = variant_compatible_siblings (
564- filenames , variant = variant , use_safetensors = use_safetensors
564+ filenames , variant = variant , ignore_patterns = ignore_patterns
565565 )
566566 assert all (variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames )
567567
0 commit comments