Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit b78b035

Browse files
committed
tests: util: net: Add cached_download test
Signed-off-by: John Andersen <[email protected]>
1 parent 87f2061 commit b78b035

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
### Added
99
- Tensorflow hub NLP models.
1010
- Notes on development dependencies in `setup.py` files to codebase notes.
11+
- Test for `cached_download`
1112
### Changed
1213
- Definitions with a `spec` can use the `subspec` parameter to declare that they
1314
are a list or a dict where the values are of the `spec` type. Rather than the

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"sphinx_rtd_theme",
6666
],
6767
},
68+
tests_require=["httptest>=0.0.15",],
6869
entry_points={
6970
"console_scripts": ["dffml = dffml.cli.cli:CLI.main"],
7071
"dffml.source": [

tests/util/test_net.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import io
2+
import os
3+
import shutil
4+
import tarfile
5+
import pathlib
6+
import tempfile
7+
import contextlib
8+
import unittest.mock
9+
10+
import httptest
11+
12+
from dffml.util.os import chdir
13+
from dffml.util.net import cached_download
14+
from dffml.util.asynctestcase import AsyncTestCase
15+
16+
17+
class TestCachedDownloadServer(httptest.Handler):
18+
def do_GET(self):
19+
self.send_response(200)
20+
self.send_header("Content-type", "application/x-gzip")
21+
self.end_headers()
22+
23+
with contextlib.ExitStack() as stack:
24+
# gzip will add a last modified time by calling time.time, to ensure
25+
# that the hash is always the same, we set the time to 0
26+
stack.enter_context(
27+
unittest.mock.patch("time.time", return_value=1)
28+
)
29+
# Create the bytes objects to build the tarfile in memory
30+
tar_fileobj = stack.enter_context(io.BytesIO())
31+
hello_txt_fileobj = stack.enter_context(io.BytesIO(b"world"))
32+
dead_bin_fileobj = stack.enter_context(io.BytesIO(b"\xBE\xEF"))
33+
# Create the TarInfo objects
34+
hello_txt_tarinfo = tarfile.TarInfo(name="somedir/hello.txt")
35+
hello_txt_tarinfo.size = len(hello_txt_fileobj.getvalue())
36+
dead_bin_tarinfo = tarfile.TarInfo(name="somedir/dead.bin")
37+
dead_bin_tarinfo.size = len(dead_bin_fileobj.getvalue())
38+
# Create the archive using the bytes objects
39+
with tarfile.open(mode="w|gz", fileobj=tar_fileobj) as archive:
40+
archive.addfile(hello_txt_tarinfo, fileobj=hello_txt_fileobj)
41+
archive.addfile(dead_bin_tarinfo, fileobj=dead_bin_fileobj)
42+
# Write out the contents of the tar to the client
43+
self.wfile.write(tar_fileobj.getvalue())
44+
45+
46+
class TestNet(AsyncTestCase):
47+
@httptest.Server(TestCachedDownloadServer)
48+
async def test_call_response(self, ts=httptest.NoServer()):
49+
with tempfile.TemporaryDirectory() as tempdir:
50+
51+
@cached_download(
52+
ts.url() + "/archive.tar.gz",
53+
pathlib.Path(tempdir) / "archive.tar.gz",
54+
"cd538a17ce51458e3315639eba0650e96740d3d6abadbf174209ee7c5cae000ac739e99d9f32c9c2ba417b0cf67e69b8",
55+
protocol_allowlist=["http://"],
56+
)
57+
async def func(filename):
58+
return filename
59+
60+
# Directory to extract to
61+
extracted = pathlib.Path(tempdir, "extracted")
62+
63+
# Unpack the archive
64+
shutil.unpack_archive(await func(), extracted)
65+
66+
# Verify contents are correct
67+
self.assertTrue((extracted / "somedir").is_dir())
68+
self.assertEqual(
69+
(extracted / "somedir" / "hello.txt").read_text(), "world"
70+
)
71+
self.assertEqual(
72+
(extracted / "somedir" / "dead.bin").read_bytes(), b"\xBE\xEF"
73+
)

0 commit comments

Comments
 (0)