Skip to content

Commit 44912ec

Browse files
committed
create download/packaging scripts for pretrained weights
1 parent 2bebb1f commit 44912ec

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import requests # type: ignore[import-untyped]
2+
import zipfile
3+
import hashlib
4+
from pathlib import Path
5+
6+
7+
def download_zip(url, download_path):
8+
response = requests.get(url)
9+
response.raise_for_status()
10+
with open(download_path, "wb") as zip_file:
11+
zip_file.write(response.content)
12+
13+
14+
def extract_zip(zip_path, extract_to):
15+
Path(extract_to).mkdir(parents=True, exist_ok=True)
16+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
17+
zip_ref.extractall(extract_to)
18+
19+
20+
def verify_sha256(file_path, expected_sha256):
21+
sha256_hash = hashlib.sha256()
22+
with open(file_path, "rb") as f:
23+
for byte_block in iter(lambda: f.read(4096), b""):
24+
sha256_hash.update(byte_block)
25+
calculated_sha256 = sha256_hash.hexdigest()
26+
27+
if calculated_sha256 == expected_sha256:
28+
print("SHA256 checksum verified successfully.")
29+
else:
30+
print("SHA256 checksum verification failed.")
31+
raise ValueError("Checksum does not match!")
32+
33+
34+
if __name__ == "__main__":
35+
zip_url = "https://github.com/MECLabTUDA/NCALab/releases/download/v0.3.2/pretrained_weights.zip"
36+
root_path = Path(__file__).parent / ".."
37+
download_path = (root_path / "pretrained_weights.zip").resolve()
38+
extract_to = root_path
39+
40+
expected_sha256 = "7122e784d57341d2832ed8a34deb0c308a860d47242109e22a424d2572eb42f5"
41+
42+
download_zip(zip_url, download_path)
43+
verify_sha256(download_path, expected_sha256)
44+
extract_zip(download_path, extract_to)

scripts/pack_example_weights.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import zipfile
2+
from pathlib import Path
3+
4+
import click
5+
6+
7+
def zip_pth_files(zip_filename):
8+
with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
9+
for pth_file in Path("tasks").rglob("*.pth"):
10+
zipf.write(pth_file, pth_file.relative_to(Path(".")))
11+
12+
if __name__ == "__main__":
13+
zip_filename = "pretrained_weights.zip"
14+
zip_path = (Path(__file__) / ".." / ".." / zip_filename).resolve()
15+
zip_pth_files(zip_path)
16+
click.secho(f"Done. You'll find all pretrained weights in {zip_path}.")

0 commit comments

Comments
 (0)