Skip to content

Commit 2089700

Browse files
committed
update
1 parent 9f9db3b commit 2089700

File tree

3 files changed

+184
-24
lines changed

3 files changed

+184
-24
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 69 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
141141
return True
142142

143143

144-
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
144+
def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -> Union[List[os.PathLike], str]:
145145
weight_names = [
146146
WEIGHTS_NAME,
147147
SAFETENSORS_WEIGHTS_NAME,
@@ -188,24 +188,85 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
188188
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
189189
non_variant_filenames = non_variant_weights | non_variant_indexes
190190

191-
# all variant filenames will be used by default
192-
usable_filenames = set(variant_filenames)
193-
194191
def find_component(filename):
195192
if not len(filename.split("/")) == 2:
196193
return
197194
component = filename.split("/")[0]
198195
return component
199196

200-
def has_variant(filename, variant_filenames):
197+
def convert_to_variant(filename):
198+
if "index" in filename:
199+
variant_filename = filename.replace("index", f"index.{variant}")
200+
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
201+
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
202+
else:
203+
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
204+
return variant_filename
205+
206+
def has_sharded_variant(filename, variant, variant_filenames):
201207
component = find_component(filename)
208+
# If component exists check for sharded variant index filename
209+
# If component doesn't exist check main dir for sharded variant index filename
202210
component = component + "/" if component else ""
211+
variant_index_re = re.compile(
212+
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
213+
)
214+
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
215+
216+
def has_non_sharded_variant(filename, variant, variant_filenames):
217+
component = find_component(filename)
218+
component = component + "/" if component else ""
219+
base_name = filename.split("/")[-1]
220+
221+
# Only apply to sharded files (those with the index format)
222+
if not (non_variant_file_re.match(base_name) or non_variant_index_re.match(base_name)):
223+
return False
203224

204-
return any(f.startswith(component) for f in variant_filenames)
225+
# Check if there's a non-sharded variant in the same component
226+
non_sharded_variants = [
227+
f
228+
for f in variant_filenames
229+
if f.startswith(component) and not re.search(transformers_index_format, f.split("/")[-1])
230+
]
231+
return any(non_sharded_variants)
232+
233+
if use_safetensors:
234+
# Keep only safetensors and index files
235+
non_variant_filenames = {
236+
f
237+
for f in non_variant_filenames
238+
if f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1])
239+
}
240+
if variant is not None:
241+
variant_filenames = {
242+
f for f in variant_filenames if f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1])
243+
}
244+
else:
245+
# Exclude safetensors files but keep index files
246+
non_variant_filenames = {
247+
f
248+
for f in non_variant_filenames
249+
if not f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1])
250+
}
251+
if variant is not None:
252+
variant_filenames = {
253+
f
254+
for f in variant_filenames
255+
if not f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1])
256+
}
257+
258+
# all variant filenames will be used by default
259+
usable_filenames = set(variant_filenames)
205260

206261
for filename in non_variant_filenames:
207-
# If a variant exists skip adding to allowed patterns
208-
if has_variant(filename, variant_filenames):
262+
if convert_to_variant(filename) in variant_filenames:
263+
continue
264+
265+
# If a sharded variant exists skip adding to allowed patterns
266+
if has_sharded_variant(filename, variant, variant_filenames):
267+
continue
268+
269+
if has_non_sharded_variant(filename, variant, variant_filenames):
209270
continue
210271

211272
usable_filenames.add(filename)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1375,7 +1375,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13751375
)
13761376
logger.warning(warn_msg)
13771377

1378-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1378+
model_filenames, variant_filenames = variant_compatible_siblings(
1379+
filenames, variant=variant, use_safetensors=use_safetensors
1380+
)
13791381

13801382
config_file = hf_hub_download(
13811383
pretrained_model_name,

tests/pipelines/test_pipeline_utils.py

Lines changed: 112 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
212212

213213
class VariantCompatibleSiblingsTest(unittest.TestCase):
214214
def test_only_non_variants_downloaded(self):
215+
use_safetensors = True
215216
variant = "fp16"
216217
filenames = [
217218
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -222,10 +223,13 @@ def test_only_non_variants_downloaded(self):
222223
"unet/diffusion_pytorch_model.safetensors",
223224
]
224225

225-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
226+
model_filenames, variant_filenames = variant_compatible_siblings(
227+
filenames, variant=None, use_safetensors=use_safetensors
228+
)
226229
assert all(variant not in f for f in model_filenames)
227230

228231
def test_only_variants_downloaded(self):
232+
use_safetensors = True
229233
variant = "fp16"
230234
filenames = [
231235
f"vae/diffusion_pytorch_model.{variant}.safetensors",
@@ -236,10 +240,13 @@ def test_only_variants_downloaded(self):
236240
"unet/diffusion_pytorch_model.safetensors",
237241
]
238242

239-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
243+
model_filenames, variant_filenames = variant_compatible_siblings(
244+
filenames, variant=variant, use_safetensors=use_safetensors
245+
)
240246
assert all(variant in f for f in model_filenames)
241247

242248
def test_mixed_variants_downloaded(self):
249+
use_safetensors = True
243250
variant = "fp16"
244251
non_variant_file = "text_encoder/model.safetensors"
245252
filenames = [
@@ -249,21 +256,27 @@ def test_mixed_variants_downloaded(self):
249256
f"unet/diffusion_pytorch_model.{variant}.safetensors",
250257
"unet/diffusion_pytorch_model.safetensors",
251258
]
252-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
259+
model_filenames, variant_filenames = variant_compatible_siblings(
260+
filenames, variant=variant, use_safetensors=use_safetensors
261+
)
253262
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
254263

255264
def test_non_variants_in_main_dir_downloaded(self):
265+
use_safetensors = True
256266
variant = "fp16"
257267
filenames = [
258268
f"diffusion_pytorch_model.{variant}.safetensors",
259269
"diffusion_pytorch_model.safetensors",
260270
"model.safetensors",
261271
f"model.{variant}.safetensors",
262272
]
263-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
273+
model_filenames, variant_filenames = variant_compatible_siblings(
274+
filenames, variant=None, use_safetensors=use_safetensors
275+
)
264276
assert all(variant not in f for f in model_filenames)
265277

266278
def test_variants_in_main_dir_downloaded(self):
279+
use_safetensors = True
267280
variant = "fp16"
268281
filenames = [
269282
f"diffusion_pytorch_model.{variant}.safetensors",
@@ -273,21 +286,27 @@ def test_variants_in_main_dir_downloaded(self):
273286
f"diffusion_pytorch_model.{variant}.safetensors",
274287
"diffusion_pytorch_model.safetensors",
275288
]
276-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
289+
model_filenames, variant_filenames = variant_compatible_siblings(
290+
filenames, variant=variant, use_safetensors=use_safetensors
291+
)
277292
assert all(variant in f for f in model_filenames)
278293

279294
def test_mixed_variants_in_main_dir_downloaded(self):
295+
use_safetensors = True
280296
variant = "fp16"
281297
non_variant_file = "model.safetensors"
282298
filenames = [
283299
f"diffusion_pytorch_model.{variant}.safetensors",
284300
"diffusion_pytorch_model.safetensors",
285301
"model.safetensors",
286302
]
287-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
303+
model_filenames, variant_filenames = variant_compatible_siblings(
304+
filenames, variant=variant, use_safetensors=use_safetensors
305+
)
288306
assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
289307

290308
def test_sharded_variants_in_main_dir_downloaded(self):
309+
use_safetensors = True
291310
variant = "fp16"
292311
filenames = [
293312
"diffusion_pytorch_model.safetensors.index.json",
@@ -298,10 +317,13 @@ def test_sharded_variants_in_main_dir_downloaded(self):
298317
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
299318
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
300319
]
301-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
320+
model_filenames, variant_filenames = variant_compatible_siblings(
321+
filenames, variant=variant, use_safetensors=use_safetensors
322+
)
302323
assert all(variant in f for f in model_filenames)
303324

304325
def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
326+
use_safetensors = True
305327
variant = "fp16"
306328
filenames = [
307329
"diffusion_pytorch_model.safetensors.index.json",
@@ -310,10 +332,13 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
310332
"diffusion_pytorch_model-00003-of-00003.safetensors",
311333
f"diffusion_pytorch_model.{variant}.safetensors",
312334
]
313-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
335+
model_filenames, variant_filenames = variant_compatible_siblings(
336+
filenames, variant=variant, use_safetensors=use_safetensors
337+
)
314338
assert all(variant in f for f in model_filenames)
315339

316340
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
341+
use_safetensors = True
317342
variant = "fp16"
318343
filenames = [
319344
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -324,10 +349,13 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
324349
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
325350
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
326351
]
327-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
352+
model_filenames, variant_filenames = variant_compatible_siblings(
353+
filenames, variant=None, use_safetensors=use_safetensors
354+
)
328355
assert all(variant not in f for f in model_filenames)
329356

330357
def test_sharded_non_variants_downloaded(self):
358+
use_safetensors = True
331359
variant = "fp16"
332360
filenames = [
333361
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -338,10 +366,13 @@ def test_sharded_non_variants_downloaded(self):
338366
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
339367
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
340368
]
341-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
369+
model_filenames, variant_filenames = variant_compatible_siblings(
370+
filenames, variant=None, use_safetensors=use_safetensors
371+
)
342372
assert all(variant not in f for f in model_filenames)
343373

344374
def test_sharded_variants_downloaded(self):
375+
use_safetensors = True
345376
variant = "fp16"
346377
filenames = [
347378
f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
@@ -352,10 +383,13 @@ def test_sharded_variants_downloaded(self):
352383
f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
353384
f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
354385
]
355-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
386+
model_filenames, variant_filenames = variant_compatible_siblings(
387+
filenames, variant=variant, use_safetensors=use_safetensors
388+
)
356389
assert all(variant in f for f in model_filenames)
357390

358391
def test_single_variant_with_sharded_non_variant_downloaded(self):
392+
use_safetensors = True
359393
variant = "fp16"
360394
filenames = [
361395
"unet/diffusion_pytorch_model.safetensors.index.json",
@@ -364,10 +398,13 @@ def test_single_variant_with_sharded_non_variant_downloaded(self):
364398
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
365399
f"unet/diffusion_pytorch_model.{variant}.safetensors",
366400
]
367-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
401+
model_filenames, variant_filenames = variant_compatible_siblings(
402+
filenames, variant=variant, use_safetensors=use_safetensors
403+
)
368404
assert all(variant in f for f in model_filenames)
369405

370406
def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
407+
use_safetensors = True
371408
variant = "fp16"
372409
allowed_non_variant = "unet"
373410
filenames = [
@@ -381,10 +418,13 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
381418
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
382419
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
383420
]
384-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
421+
model_filenames, variant_filenames = variant_compatible_siblings(
422+
filenames, variant=variant, use_safetensors=use_safetensors
423+
)
385424
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
386425

387426
def test_sharded_mixed_variants_downloaded(self):
427+
use_safetensors = True
388428
variant = "fp16"
389429
allowed_non_variant = "unet"
390430
filenames = [
@@ -400,15 +440,72 @@ def test_sharded_mixed_variants_downloaded(self):
400440
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
401441
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
402442
]
403-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
443+
model_filenames, variant_filenames = variant_compatible_siblings(
444+
filenames, variant=variant, use_safetensors=use_safetensors
445+
)
404446
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
405447

448+
def test_variant_ignored_if_use_safetensors(self):
449+
use_safetensors = True
450+
variant = "fp16"
451+
filenames = [
452+
f"vae/diffusion_pytorch_model.{variant}.bin",
453+
f"text_encoder/model.{variant}.bin",
454+
f"unet/diffusion_pytorch_model.{variant}.bin",
455+
"vae/diffusion_pytorch_model.safetensors",
456+
"text_encoder/model.safetensors",
457+
"unet/diffusion_pytorch_model.safetensors",
458+
]
459+
model_filenames, variant_filenames = variant_compatible_siblings(
460+
filenames, variant=variant, use_safetensors=use_safetensors
461+
)
462+
assert all(variant not in f for f in model_filenames)
463+
406464
def test_downloading_when_no_variant_exists(self):
465+
use_safetensors = True
407466
variant = "fp16"
408467
filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
409-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
468+
model_filenames, variant_filenames = variant_compatible_siblings(
469+
filenames, variant=variant, use_safetensors=use_safetensors
470+
)
410471
assert len(model_filenames) != 0
411472

473+
def test_downloading_use_safetensors_no_variant_exists(self):
474+
use_safetensors = True
475+
variant = "fp16"
476+
filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"]
477+
model_filenames, variant_filenames = variant_compatible_siblings(
478+
filenames, variant=variant, use_safetensors=use_safetensors
479+
)
480+
assert all(variant not in f for f in model_filenames)
481+
482+
def test_downloading_use_safetensors_false(self):
483+
use_safetensors = False
484+
variant = "fp16"
485+
filenames = [
486+
"text_encoder/model.bin",
487+
"unet/diffusion_pytorch_model.bin",
488+
"unet/diffusion_pytorch_model.safetensors",
489+
]
490+
model_filenames, variant_filenames = variant_compatible_siblings(
491+
filenames, variant=variant, use_safetensors=use_safetensors
492+
)
493+
494+
assert all(".safetensors" not in f for f in model_filenames)
495+
496+
def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
497+
use_safetensors = True
498+
variant = "fp16"
499+
allowed_non_variant = "diffusion_pytorch_model.safetensors"
500+
filenames = [
501+
f"unet/diffusion_pytorch_model.{variant}.safetensors",
502+
"diffusion_pytorch_model.safetensors",
503+
]
504+
model_filenames, variant_filenames = variant_compatible_siblings(
505+
filenames, variant=variant, use_safetensors=use_safetensors
506+
)
507+
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
508+
412509

413510
class ProgressBarTests(unittest.TestCase):
414511
def get_dummy_components_image_generation(self):

0 commit comments

Comments
 (0)