Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit ea3ba20

Browse files
committed
util: net: cached_download_unpack_archive
Signed-off-by: John Andersen <[email protected]>
1 parent 60675e2 commit ea3ba20

File tree

8 files changed

+140
-18
lines changed

8 files changed

+140
-18
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
- Tensorflow hub NLP models.
1010
- Notes on development dependencies in `setup.py` files to codebase notes.
1111
- Test for `cached_download`
12+
- `dffml.util.net.cached_download_unpack_archive` to run a cached download and
13+
unpack the archive, very useful for testing. Documented on the Networking
14+
Helpers API docs page.
1215
- Directions on how to read the CI under the Git and GitHub page of the
1316
contributing documentation.
1417
- HTTP API

dffml/util/net.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import shutil
12
import hashlib
23
import pathlib
34
import functools
45
import urllib.request
56
from typing import List
67

8+
from .os import chdir
9+
710

811
class HashValidationError(Exception):
912
"""
@@ -34,10 +37,16 @@ def __str__(self):
3437
return f"Protocol of URL {self.url!r} is not in allowlist: {self.allowlist!r}"
3538

3639

40+
# Default list of URL protocols allowed
3741
DEFAULT_PROTOCOL_ALLOWLIST: List[str] = ["https://"]
3842

3943

4044
def sync_urlopen(url, protocol_allowlist=DEFAULT_PROTOCOL_ALLOWLIST):
45+
"""
46+
Check that ``url`` has a protocol defined in ``protocol_allowlist``, then
47+
return the result of calling :py:func:`urllib.request.urlopen` passing it
48+
``url``.
49+
"""
4150
allowed_protocol = False
4251
for protocol in protocol_allowlist:
4352
if url.startswith(protocol):
@@ -54,6 +63,24 @@ def cached_download(
5463
expected_hash,
5564
protocol_allowlist=DEFAULT_PROTOCOL_ALLOWLIST,
5665
):
66+
"""
67+
Download a file and verify the hash of the downloaded file. If the file
68+
already exists and the hash matches, do not re-download the file.
69+
70+
Examples
71+
--------
72+
73+
>>> @cached_download(
74+
... "https://github.com/intel/dffml/raw/152c2b92535fac6beec419236f8639b0d75d707d/MANIFEST.in",
75+
... "MANIFEST.in",
76+
... "f7aadf5cdcf39f161a779b4fa77ec56a49630cf7680e21fb3dc6c36ce2d8c6fae0d03d5d3094a6aec4fea1561393c14c",
77+
... )
78+
... async def first_line_in_manifest_152c2b(manifest):
79+
... return manifest.read_text().split()[:2]
80+
>>>
81+
>>> asyncio.run(first_line_in_manifest_152c2b())
82+
['include', 'README.md']
83+
"""
5784
target_path = pathlib.Path(target_path)
5885

5986
def validate_hash(error: bool = True):
@@ -69,7 +96,7 @@ def validate_hash(error: bool = True):
6996
def mkwrapper(func):
7097
@functools.wraps(func)
7198
async def wrapper(*args, **kwds):
72-
args = list(args) + [str(target_path)]
99+
args = list(args) + [target_path]
73100
if not target_path.is_file() or not validate_hash(error=False):
74101
# TODO(p5) Blocking request in coroutine
75102
with sync_urlopen(
@@ -82,3 +109,60 @@ async def wrapper(*args, **kwds):
82109
return wrapper
83110

84111
return mkwrapper
112+
113+
114+
def cached_download_unpack_archive(
    url,
    file_path,
    directory_path,
    expected_hash,
    protocol_allowlist=DEFAULT_PROTOCOL_ALLOWLIST,
):
    """
    Download an archive and extract it to a directory on disk.

    Verify the hash of the downloaded file. If the file already exists and its
    hash matches, the file is not re-downloaded.

    The decorated coroutine function is called with ``directory_path`` (as a
    :py:class:`pathlib.Path`) appended to its positional arguments.

    .. warning::

        This function does not verify the integrity of the unpacked archive on
        disk, only the integrity of the downloaded file.

    Parameters
    ----------
    url : str
        URL of the archive to download. Its protocol must be in
        ``protocol_allowlist``.
    file_path : str or pathlib.Path
        Location on disk where the downloaded archive is stored.
    directory_path : str or pathlib.Path
        Directory the archive is unpacked into. It is created, along with any
        missing parents, if it does not exist.
    expected_hash : str
        Expected hash of the downloaded archive, passed through to
        :py:func:`cached_download`.
    protocol_allowlist : list of str
        URL protocols that are allowed, passed through to
        :py:func:`cached_download`.

    Examples
    --------

    >>> @cached_download_unpack_archive(
    ...     "https://github.com/intel/dffml/archive/152c2b92535fac6beec419236f8639b0d75d707d.tar.gz",
    ...     "dffml.tar.gz",
    ...     "dffml",
    ...     "32ba082cd8056ff4ddcb68691a590c3cb8fea2ff75c0265b8d844c5edc7eaef54136160c6090750e562059b957355b15",
    ... )
    ... async def files_in_dffml_commit_152c2b(dffml_dir):
    ...     return len(list(dffml_dir.rglob("**/*")))
    >>>
    >>> asyncio.run(files_in_dffml_commit_152c2b())
    594
    """
    directory_path = pathlib.Path(directory_path)

    async def extractor(download_path):
        # Hand the destination to unpack_archive directly instead of changing
        # into it. The working directory is process-wide state, so changing
        # it from within a coroutine could affect unrelated code.
        shutil.unpack_archive(str(download_path), str(directory_path))

    # cached_download takes care of downloading and validating the archive,
    # then calls extractor with the path of the downloaded file.
    extract = cached_download(
        url, file_path, expected_hash, protocol_allowlist=protocol_allowlist
    )(extractor)

    def mkwrapper(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwds):
            # exist_ok avoids the race between checking for the directory and
            # creating it.
            directory_path.mkdir(parents=True, exist_ok=True)
            await extract()
            return await func(*args, directory_path, **kwds)

        return wrapper

    return mkwrapper

docs/api/util/index.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
API Helper Utilities Reference
2-
==============================
1+
Utilities
2+
=========
33

44
:py:mod:`asyncio` testing, command line, and configuration helpers live here.
55

@@ -9,3 +9,4 @@ API Helper Utilities Reference
99
:caption: Contents:
1010

1111
asynchelper
12+
net

docs/api/util/net.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Networking Helpers
2+
==================
3+
4+
.. automodule:: dffml.util.net
5+
:members:

docs/conf.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,23 @@
8888
napoleon_numpy_docstring = True
8989

9090
doctest_global_setup = """
91+
import os
9192
import sys
93+
import shutil
94+
import atexit
9295
import inspect
9396
import asyncio
97+
import tempfile
98+
import functools
99+
100+
# Create a temporary directory for test to run in
101+
DOCTEST_TEMPDIR = tempfile.mkdtemp()
102+
# Remove it when the test exits
103+
atexit.register(functools.partial(shutil.rmtree, DOCTEST_TEMPDIR))
104+
# Change the current working directory to the temporary directory
105+
os.chdir(DOCTEST_TEMPDIR)
94106
95107
from dffml.base import *
96108
from dffml.df.base import *
109+
from dffml.util.net import *
97110
"""

scripts/docs.sh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,5 @@ ln -s "${PWD}/service/http/docs/" docs/plugins/service/http
1010
python3.7 scripts/docs.py
1111
python3.7 -c 'import os, pkg_resources; [e.load() for e in pkg_resources.iter_entry_points("console_scripts") if e.name.startswith("sphinx-build")][0]()' -b html docs pages \
1212
|| (echo "[ERROR] Failed run sphinx, is it installed (pip install -U .[dev])?" 1>&2 ; exit 1)
13-
find pages/ -name \*.html -exec \
14-
sed -i 's/<span class="gp">\&gt;\&gt;\&gt; <\/span>//g' {} \;
15-
find pages/ -name \*.html -exec \
16-
sed -i 's/<span class="go">\&gt;\&gt;\&gt;<\/span>//g' {} \;
1713
cp -r docs/images pages/
1814
touch pages/.nojekyll

tests/source/test_idx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TestIDXSources(AsyncTestCase):
3131
async def test_idx1(self, filename):
3232
feature_name = "label"
3333
async with IDX1Source(
34-
IDX1SourceConfig(filename=filename, feature=feature_name)
34+
IDX1SourceConfig(filename=str(filename), feature=feature_name)
3535
) as source:
3636
async with source() as sctx:
3737
records = [record async for record in sctx.records()]
@@ -43,7 +43,7 @@ async def test_idx1(self, filename):
4343
async def test_idx3(self, filename):
4444
feature_name = "image"
4545
async with IDX3Source(
46-
IDX3SourceConfig(filename=filename, feature=feature_name)
46+
IDX3SourceConfig(filename=str(filename), feature=feature_name)
4747
) as source:
4848
async with source() as sctx:
4949
records = [record async for record in sctx.records()]

tests/util/test_net.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import httptest
1111

1212
from dffml.util.os import chdir
13-
from dffml.util.net import cached_download
13+
from dffml.util.net import cached_download, cached_download_unpack_archive
1414
from dffml.util.asynctestcase import AsyncTestCase
1515

1616

@@ -44,8 +44,17 @@ def do_GET(self):
4444

4545

4646
class TestNet(AsyncTestCase):
47+
def verify_extracted_contents(self, extracted):
    """
    Assert that ``extracted`` holds the expected unpacked archive: a
    ``somedir`` directory containing ``hello.txt`` and ``dead.bin`` with the
    expected contents.
    """
    somedir = extracted / "somedir"
    self.assertTrue(somedir.is_dir())
    hello_text = (somedir / "hello.txt").read_text()
    self.assertEqual(hello_text, "world")
    dead_bytes = (somedir / "dead.bin").read_bytes()
    self.assertEqual(dead_bytes, b"\xBE\xEF")
55+
4756
@httptest.Server(TestCachedDownloadServer)
48-
async def test_call_response(self, ts=httptest.NoServer()):
57+
async def test_cached_download(self, ts=httptest.NoServer()):
4958
with tempfile.TemporaryDirectory() as tempdir:
5059

5160
@cached_download(
@@ -63,11 +72,22 @@ async def func(filename):
6372
# Unpack the archive
6473
shutil.unpack_archive(await func(), extracted)
6574

66-
# Verify contents are correct
67-
self.assertTrue((extracted / "somedir").is_dir())
68-
self.assertEqual(
69-
(extracted / "somedir" / "hello.txt").read_text(), "world"
70-
)
71-
self.assertEqual(
72-
(extracted / "somedir" / "dead.bin").read_bytes(), b"\xBE\xEF"
75+
self.verify_extracted_contents(extracted)
76+
77+
@httptest.Server(TestCachedDownloadServer)
async def test_cached_download_unpack_archive(
    self, ts=httptest.NoServer()
):
    # Fetch archive.tar.gz from the test server, unpack it into a temporary
    # directory, and check the unpacked files.
    with tempfile.TemporaryDirectory() as tempdir:
        workdir = pathlib.Path(tempdir)

        @cached_download_unpack_archive(
            ts.url() + "/archive.tar.gz",
            workdir / "archive.tar.gz",
            workdir / "archive",
            "cd538a17ce51458e3315639eba0650e96740d3d6abadbf174209ee7c5cae000ac739e99d9f32c9c2ba417b0cf67e69b8",
            protocol_allowlist=["http://"],
        )
        async def unpacked(extracted):
            # The decorator supplies the extraction directory; hand it back
            return extracted

        self.verify_extracted_contents(await unpacked())

0 commit comments

Comments
 (0)