Skip to content

Commit 12be157

Browse files
committed
download
1 parent a218be3 commit 12be157

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from .. import __version__
3030
from ..utils import (
31+
FLASHPACK_WEIGHTS_NAME,
3132
FLAX_WEIGHTS_NAME,
3233
ONNX_EXTERNAL_WEIGHTS_NAME,
3334
ONNX_WEIGHTS_NAME,
@@ -194,6 +195,7 @@ def filter_model_files(filenames):
194195
FLAX_WEIGHTS_NAME,
195196
ONNX_WEIGHTS_NAME,
196197
ONNX_EXTERNAL_WEIGHTS_NAME,
198+
FLASHPACK_WEIGHTS_NAME,
197199
]
198200

199201
if is_transformers_available():
@@ -1091,6 +1093,7 @@ def _get_ignore_patterns(
10911093
allow_pickle: bool,
10921094
use_onnx: bool,
10931095
is_onnx: bool,
1096+
use_flashpack: bool,
10941097
variant: Optional[str] = None,
10951098
) -> List[str]:
10961099
if (
@@ -1116,6 +1119,9 @@ def _get_ignore_patterns(
11161119
if not use_onnx:
11171120
ignore_patterns += ["*.onnx", "*.pb"]
11181121

1122+
elif use_flashpack:
1123+
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
1124+
11191125
else:
11201126
ignore_patterns = ["*.safetensors", "*.msgpack"]
11211127

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
15511551
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
15521552
option should only be set to `True` for repositories you trust and in which you have read the code, as
15531553
it will execute code present on the Hub on your local machine.
1554+
use_onnx (`bool`, *optional*, defaults to `False`):
1555+
If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, ONNX weights
1556+
will never be downloaded.
15541557
15551558
Returns:
15561559
`os.PathLike`:
@@ -1575,6 +1578,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
15751578
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
15761579
trust_remote_code = kwargs.pop("trust_remote_code", False)
15771580
dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
1581+
use_flashpack = kwargs.pop("use_flashpack", True)
15781582

15791583
if dduf_file:
15801584
if custom_pipeline:
@@ -1694,6 +1698,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
16941698
allow_pickle,
16951699
use_onnx,
16961700
pipeline_class._is_onnx,
1701+
use_flashpack,
16971702
variant,
16981703
)
16991704

0 commit comments

Comments
 (0)