Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 1c53cf1

Browse files
peakjifardeon
andauthored
build(utils): only download weights index of the selected tensor format (#177)
* build(utils): only download weights index of the selected tensor format * build(utils): preserve markdown files in downloads --------- Co-authored-by: Gideon Giffard <118290024+fardeon@users.noreply.github.com>
1 parent 4b1ce7e commit 1c53cf1

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

utils/download.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,38 @@
1313
sys.exit("usage: python download.py REPO_ID LOCAL_DIR [REVISION]")
1414

1515
if os.getenv("TENSOR_FORMAT") == "safetensors":
16-
tensor_format = "*.safetensors"
16+
allow_patterns = ["*.safetensors", "*.safetensors.index.json"]
1717
else:
18-
tensor_format = "*.bin"
18+
allow_patterns = ["*.bin", "*.bin.index.json"]
19+
20+
ignore_patterns = [
21+
".*",
22+
"*.index.json",
23+
"*.bin",
24+
"*.ckpt",
25+
"*.h5",
26+
"*.mlmodel",
27+
"*.msgpack",
28+
"*.onnx",
29+
"*.ot",
30+
"*.pb",
31+
"*.safetensors",
32+
"*.tar.gz",
33+
"*.tflite",
34+
]
35+
36+
kwargs = {
37+
"repo_id": sys.argv[1],
38+
"local_dir": sys.argv[2],
39+
"revision": sys.argv[3] if len(sys.argv) > 3 else None,
40+
"local_dir_use_symlinks": False,
41+
"resume_download": True,
42+
}
1943

2044
with tempfile.TemporaryDirectory() as cache_dir:
2145
huggingface_hub.snapshot_download(
22-
repo_id=sys.argv[1],
23-
local_dir=sys.argv[2],
24-
revision=sys.argv[3] if len(sys.argv) > 3 else None,
25-
cache_dir=cache_dir,
26-
local_dir_use_symlinks=False,
27-
resume_download=True,
28-
allow_patterns=[tensor_format, "*.json", "*.model", "*.py"],
46+
cache_dir=cache_dir, ignore_patterns=ignore_patterns, **kwargs
47+
)
48+
huggingface_hub.snapshot_download(
49+
cache_dir=cache_dir, allow_patterns=allow_patterns, **kwargs
2950
)

0 commit comments

Comments
 (0)