Skip to content

Commit 02b0892

Browse files
committed
update
1 parent b79e720 commit 02b0892

File tree

2 files changed

+71
-21
lines changed

2 files changed

+71
-21
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,21 @@
2222
import sys
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import (Any, Callable, Dict, List, Optional, Union, get_args,
26-
get_origin)
25+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
2726

2827
import numpy as np
2928
import PIL.Image
3029
import requests
3130
import torch
32-
from huggingface_hub import (DDUFEntry, ModelCard, create_repo,
33-
hf_hub_download, model_info, read_dduf_file,
34-
snapshot_download)
31+
from huggingface_hub import (
32+
DDUFEntry,
33+
ModelCard,
34+
create_repo,
35+
hf_hub_download,
36+
model_info,
37+
read_dduf_file,
38+
snapshot_download,
39+
)
3540
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
3641
from packaging import version
3742
from requests.exceptions import HTTPError
@@ -45,28 +50,51 @@
4550
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
4651
from ..quantizers.bitsandbytes.utils import _check_bnb_status
4752
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
48-
from ..utils import (CONFIG_NAME, DEPRECATED_REVISION_ARGS, BaseOutput,
49-
PushToHubMixin, is_accelerate_available,
50-
is_accelerate_version, is_torch_npu_available,
51-
is_torch_version, is_transformers_version, logging,
52-
numpy_to_pil)
53-
from ..utils.hub_utils import (_check_legacy_sharding_variant_format,
54-
load_or_create_model_card, populate_model_card)
53+
from ..utils import (
54+
CONFIG_NAME,
55+
DEPRECATED_REVISION_ARGS,
56+
BaseOutput,
57+
PushToHubMixin,
58+
is_accelerate_available,
59+
is_accelerate_version,
60+
is_torch_npu_available,
61+
is_torch_version,
62+
is_transformers_version,
63+
logging,
64+
numpy_to_pil,
65+
)
66+
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
5567
from ..utils.torch_utils import is_compiled_module
5668

69+
5770
if is_torch_npu_available():
5871
import torch_npu # noqa: F401
5972

6073
from .pipeline_loading_utils import (
61-
ALL_IMPORTABLE_CLASSES, CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME,
62-
LOADABLE_CLASSES, _download_dduf_file, _fetch_class_library_tuple,
63-
_get_custom_components_and_folders, _get_custom_pipeline_class,
64-
_get_final_device_map, _get_ignore_patterns, _get_pipeline_class,
65-
_identify_model_variants, _maybe_raise_error_for_incorrect_transformers,
66-
_maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls,
67-
_unwrap_model, _update_init_kwargs_with_connected_pipeline,
68-
filter_model_files, load_sub_model, maybe_raise_or_warn,
69-
variant_compatible_siblings, warn_deprecated_model_variant)
74+
ALL_IMPORTABLE_CLASSES,
75+
CONNECTED_PIPES_KEYS,
76+
CUSTOM_PIPELINE_FILE_NAME,
77+
LOADABLE_CLASSES,
78+
_download_dduf_file,
79+
_fetch_class_library_tuple,
80+
_get_custom_components_and_folders,
81+
_get_custom_pipeline_class,
82+
_get_final_device_map,
83+
_get_ignore_patterns,
84+
_get_pipeline_class,
85+
_identify_model_variants,
86+
_maybe_raise_error_for_incorrect_transformers,
87+
_maybe_raise_warning_for_inpainting,
88+
_resolve_custom_pipeline_and_cls,
89+
_unwrap_model,
90+
_update_init_kwargs_with_connected_pipeline,
91+
filter_model_files,
92+
load_sub_model,
93+
maybe_raise_or_warn,
94+
variant_compatible_siblings,
95+
warn_deprecated_model_variant,
96+
)
97+
7098

7199
if is_accelerate_available():
72100
import accelerate

tests/pipelines/test_pipeline_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,28 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and
540540
)
541541
assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
542542

543+
def test_download_onnx_models(self):
544+
ignore_patterns = ["*.safetensors"]
545+
filenames = [
546+
"vae/model.onnx",
547+
"unet/model.onnx",
548+
]
549+
model_filenames, variant_filenames = variant_compatible_siblings(
550+
filenames, variant=None, ignore_patterns=ignore_patterns
551+
)
552+
assert model_filenames == set(filenames)
553+
554+
def test_download_flax_models(self):
555+
ignore_patterns = ["*.safetensors", "*.bin"]
556+
filenames = [
557+
"vae/diffusion_flax_model.msgpack",
558+
"unet/diffusion_flax_model.msgpack",
559+
]
560+
model_filenames, variant_filenames = variant_compatible_siblings(
561+
filenames, variant=None, ignore_patterns=ignore_patterns
562+
)
563+
assert model_filenames == set(filenames)
564+
543565

544566
class ProgressBarTests(unittest.TestCase):
545567
def get_dummy_components_image_generation(self):

0 commit comments

Comments
 (0)