Skip to content

Commit a4bdc97

Browse files
committed
update
1 parent 2089700 commit a4bdc97

File tree

2 files changed

+107
-80
lines changed

2 files changed

+107
-80
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 48 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -177,99 +177,67 @@ def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -
177177
# `text_encoder/pytorch_model.bin.index.json`
178178
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
179179

180-
if variant is not None:
181-
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
182-
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
183-
variant_filenames = variant_weights | variant_indexes
184-
else:
185-
variant_filenames = set()
180+
def filter_for_compatible_extensions(filenames, variant=None, use_safetensors=True):
181+
def is_safetensors(filename):
182+
return ".safetensors" in filename
186183

187-
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
188-
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
189-
non_variant_filenames = non_variant_weights | non_variant_indexes
184+
def is_not_safetensors(filename):
185+
return ".safetensors" not in filename
190186

191-
def find_component(filename):
192-
if not len(filename.split("/")) == 2:
193-
return
194-
component = filename.split("/")[0]
195-
return component
196-
197-
def convert_to_variant(filename):
198-
if "index" in filename:
199-
variant_filename = filename.replace("index", f"index.{variant}")
200-
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
201-
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
187+
if use_safetensors and is_safetensors_compatible(filenames):
188+
extension_filter = is_safetensors
202189
else:
203-
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
204-
return variant_filename
205-
206-
def has_sharded_variant(filename, variant, variant_filenames):
207-
component = find_component(filename)
208-
# If component exists check for sharded variant index filename
209-
# If component doesn't exist check main dir for sharded variant index filename
210-
component = component + "/" if component else ""
211-
variant_index_re = re.compile(
212-
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
213-
)
214-
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
215-
216-
def has_non_sharded_variant(filename, variant, variant_filenames):
217-
component = find_component(filename)
218-
component = component + "/" if component else ""
219-
base_name = filename.split("/")[-1]
220-
221-
# Only apply to sharded files (those with the index format)
222-
if not (non_variant_file_re.match(base_name) or non_variant_index_re.match(base_name)):
223-
return False
190+
extension_filter = is_not_safetensors
224191

225-
# Check if there's a non-sharded variant in the same component
226-
non_sharded_variants = [
227-
f
228-
for f in variant_filenames
229-
if f.startswith(component) and not re.search(transformers_index_format, f.split("/")[-1])
230-
]
231-
return any(non_sharded_variants)
232-
233-
if use_safetensors:
234-
# Keep only safetensors and index files
235-
non_variant_filenames = {
236-
f
237-
for f in non_variant_filenames
238-
if f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1])
192+
tensor_files = {f for f in filenames if extension_filter(f)}
193+
non_variant_indexes = {
194+
f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
239195
}
240-
if variant is not None:
241-
variant_filenames = {
242-
f for f in variant_filenames if f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1])
243-
}
244-
else:
245-
# Exclude safetensors files but keep index files
246-
non_variant_filenames = {
196+
variant_indexes = {
247197
f
248-
for f in non_variant_filenames
249-
if not f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1])
198+
for f in filenames
199+
if variant is not None and variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f)
250200
}
251-
if variant is not None:
252-
variant_filenames = {
253-
f
254-
for f in variant_filenames
255-
if not f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1])
256-
}
257201

258-
# all variant filenames will be used by default
259-
usable_filenames = set(variant_filenames)
202+
return tensor_files | non_variant_indexes | variant_indexes
260203

261-
for filename in non_variant_filenames:
262-
if convert_to_variant(filename) in variant_filenames:
263-
continue
204+
def filter_for_weights_and_indexes(filenames, file_re, index_re):
205+
weights = {f for f in filenames if file_re.match(f.split("/")[-1]) is not None}
206+
indexes = {f for f in filenames if index_re.match(f.split("/")[-1]) is not None}
207+
filtered_filenames = weights | indexes
264208

265-
# If a sharded variant exists skip adding to allowed patterns
266-
if has_sharded_variant(filename, variant, variant_filenames):
267-
continue
209+
return filtered_filenames
268210

269-
if has_non_sharded_variant(filename, variant, variant_filenames):
211+
# Group files by component
212+
components = {}
213+
for filename in filenames:
214+
if not len(filename.split("/")) == 2:
215+
components.setdefault("", []).append(filename)
270216
continue
271217

272-
usable_filenames.add(filename)
218+
component, _ = filename.split("/")
219+
components.setdefault(component, []).append(filename)
220+
221+
usable_filenames = set()
222+
variant_filenames = set()
223+
for component, component_filenames in components.items():
224+
component_filenames = filter_for_compatible_extensions(
225+
component_filenames, variant=variant, use_safetensors=use_safetensors
226+
)
227+
228+
component_variants = set()
229+
if variant is not None:
230+
component_variants = filter_for_weights_and_indexes(component_filenames, variant_file_re, variant_index_re)
231+
232+
if component_variants:
233+
variant_filenames.update(component_variants)
234+
usable_filenames.update(component_variants)
235+
236+
else:
237+
component_non_variants = filter_for_weights_and_indexes(
238+
component_filenames, non_variant_file_re, non_variant_index_re
239+
)
240+
usable_filenames.update(component_non_variants)
273241

274242
return usable_filenames, variant_filenames
275243

tests/pipelines/test_pipeline_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,65 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
506506
)
507507
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
508508

509+
def test_download_variants_when_component_has_no_variant(self):
510+
use_safetensors = True
511+
variant = "fp16"
512+
filenames = [
513+
f"unet/diffusion_pytorch_model.{variant}.bin",
514+
"vae/diffusion_pytorch_model.safetensors",
515+
f"vae/diffusion_pytorch_model.{variant}.safetensors",
516+
]
517+
model_filenames, variant_filenames = variant_compatible_siblings(
518+
filenames, variant=variant, use_safetensors=use_safetensors
519+
)
520+
assert {
521+
f"unet/diffusion_pytorch_model.{variant}.bin",
522+
f"vae/diffusion_pytorch_model.{variant}.safetensors",
523+
} == model_filenames
524+
525+
def test_download_sharded_variants_when_component_has_no_safetensors_variant(self):
526+
use_safetensors = True
527+
variant = "fp16"
528+
filenames = [
529+
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
530+
"vae/diffusion_pytorch_model.safetensors.index.json",
531+
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
532+
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
533+
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
534+
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
535+
"unet/diffusion_pytorch_model.safetensors.index.json",
536+
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
537+
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
538+
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
539+
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
540+
]
541+
model_filenames, variant_filenames = variant_compatible_siblings(
542+
filenames, variant=variant, use_safetensors=use_safetensors
543+
)
544+
assert all(variant not in f for f in model_filenames)
545+
546+
def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
547+
use_safetensors = False
548+
allowed_non_variant = "unet"
549+
variant = "fp16"
550+
filenames = [
551+
f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
552+
"vae/diffusion_pytorch_model.safetensors.index.json",
553+
f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
554+
"vae/diffusion_pytorch_model-00001-of-00003.safetensors",
555+
"vae/diffusion_pytorch_model-00002-of-00003.safetensors",
556+
"vae/diffusion_pytorch_model-00003-of-00003.safetensors",
557+
"unet/diffusion_pytorch_model.safetensors.index.json",
558+
"unet/diffusion_pytorch_model-00001-of-00003.safetensors",
559+
"unet/diffusion_pytorch_model-00002-of-00003.safetensors",
560+
"unet/diffusion_pytorch_model-00003-of-00003.safetensors",
561+
f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
562+
]
563+
model_filenames, variant_filenames = variant_compatible_siblings(
564+
filenames, variant=variant, use_safetensors=use_safetensors
565+
)
566+
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
567+
509568

510569
class ProgressBarTests(unittest.TestCase):
511570
def get_dummy_components_image_generation(self):

0 commit comments

Comments
 (0)