Skip to content

Commit 9f0ae2f

Browse files
committed
update
1 parent 403417e commit 9f0ae2f

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,29 +197,16 @@ def find_component(filename):
197197
component = filename.split("/")[0]
198198
return component
199199

200-
def has_sharded_variant(component, variant, variant_filenames):
201-
# If component exists check for sharded variant index filename
202-
# If component doesn't exist check main dir for sharded variant index filename
200+
def has_variant(filename, variant_filenames):
201+
component = find_component(filename)
203202
component = component + "/" if component else ""
204-
variant_index_re = re.compile(
205-
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
206-
)
207-
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
208203

209-
def has_variant(component, variant_filenames):
210-
component = component + "/" if component else ""
211204
# Check for any variant file in this component
212205
return any(f.startswith(component) for f in variant_filenames)
213206

214207
for filename in non_variant_filenames:
215-
component = find_component(filename)
216-
217-
# Determine if sharded variant exists based on index file
218-
if has_sharded_variant(component, variant, variant_filenames):
219-
continue
220-
221208
# If a variant exists skip adding to allowed patterns
222-
if has_variant(component, variant_filenames):
209+
if has_variant(filename, variant_filenames):
223210
continue
224211

225212
usable_filenames.add(filename)

tests/pipelines/test_pipeline_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,20 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
313313
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
314314
assert all(variant in f for f in model_filenames)
315315

316+
def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
317+
variant = "fp16"
318+
filenames = [
319+
f"diffusion_pytorch_model.safetensors.index.{variant}.json",
320+
"diffusion_pytorch_model.safetensors.index.json",
321+
"diffusion_pytorch_model-00001-of-00003.safetensors",
322+
"diffusion_pytorch_model-00002-of-00003.safetensors",
323+
"diffusion_pytorch_model-00003-of-00003.safetensors",
324+
f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
325+
f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
326+
]
327+
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)
328+
assert all(variant not in f for f in model_filenames)
329+
316330
def test_sharded_non_variants_downloaded(self):
317331
variant = "fp16"
318332
filenames = [

0 commit comments

Comments
 (0)