Skip to content

Commit 47a85d7

Browse files
committed
Refactor the cachedownloader to bound/unbound
Separate the existing singular downloader into two distinct objects: the bound and unbound variants. An unbound downloader implements the core logic, almost to completion. A bound downloader *contains* an unbound one and adds a known file target (the remote and local names to use). The two are tied together via a single method: `CacheDownloader.bind(URI, name) -> BoundCacheDownloader`. The result allows a CacheDownloader to be built and then bound multiple times.
1 parent 53793e2 commit 47a85d7

File tree

5 files changed

+143
-133
lines changed

5 files changed

+143
-133
lines changed

src/check_jsonschema/cachedownloader.py

Lines changed: 111 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -11,103 +11,102 @@
1111

1212
import requests
1313

14+
# this will let us do any other caching we might need in the future in the same
15+
# cache dir (adjacent to "downloads")
16+
_CACHEDIR_NAME = os.path.join("check_jsonschema", "downloads")
17+
18+
_LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
19+
20+
21+
def _get_default_cache_dir() -> str | None:
22+
sysname = platform.system()
23+
24+
# on windows, try to get the appdata env var
25+
# this *could* result in cache_dir=None, which is fine, just skip caching in
26+
# that case
27+
if sysname == "Windows":
28+
cache_dir = os.getenv("LOCALAPPDATA", os.getenv("APPDATA"))
29+
# macOS -> app support dir
30+
elif sysname == "Darwin":
31+
cache_dir = os.path.expanduser("~/Library/Caches")
32+
# default for unknown platforms, namely linux behavior
33+
# use XDG env var and default to ~/.cache/
34+
else:
35+
cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
36+
37+
if cache_dir:
38+
cache_dir = os.path.join(cache_dir, _CACHEDIR_NAME)
39+
40+
return cache_dir
41+
42+
43+
def _lastmod_from_response(response: requests.Response) -> float:
44+
try:
45+
return time.mktime(
46+
time.strptime(response.headers["last-modified"], _LASTMOD_FMT)
47+
)
48+
# OverflowError: time outside of platform-specific bounds
49+
# ValueError: malformed/unparseable
50+
# LookupError: no such header
51+
except (OverflowError, ValueError, LookupError):
52+
return 0.0
53+
54+
55+
def _get_request(
56+
file_url: str, *, response_ok: t.Callable[[requests.Response], bool]
57+
) -> requests.Response:
58+
try:
59+
r: requests.Response | None = None
60+
for _attempt in range(3):
61+
r = requests.get(file_url, stream=True)
62+
if r.ok and response_ok(r):
63+
return r
64+
assert r is not None
65+
raise FailedDownloadError(
66+
f"got response with status={r.status_code}, retries exhausted"
67+
)
68+
except requests.RequestException as e:
69+
raise FailedDownloadError("encountered error during download") from e
70+
71+
72+
def _atomic_write(dest: str, content: bytes) -> None:
73+
# download to a temp file and then move to the dest
74+
# this makes the download safe if run in parallel (parallel runs
75+
# won't create a new empty file for writing and cause failures)
76+
fp = tempfile.NamedTemporaryFile(mode="wb", delete=False)
77+
fp.write(content)
78+
fp.close()
79+
shutil.copy(fp.name, dest)
80+
os.remove(fp.name)
81+
82+
83+
def _cache_hit(cachefile: str, response: requests.Response) -> bool:
84+
# no file? miss
85+
if not os.path.exists(cachefile):
86+
return False
87+
88+
# compare mtime on any cached file against the remote last-modified time
89+
# it is considered a hit if the local file is at least as new as the remote file
90+
local_mtime = os.path.getmtime(cachefile)
91+
remote_mtime = _lastmod_from_response(response)
92+
return local_mtime >= remote_mtime
93+
1494

1595
class FailedDownloadError(Exception):
1696
pass
1797

1898

1999
class CacheDownloader:
20-
_LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
21-
22-
# changed in v0.5.0
23-
# original cache dir was "jsonschema_validate"
24-
# this will let us do any other caching we might need in the future in the same
25-
# cache dir (adjacent to "downloads")
26-
_CACHEDIR_NAME = os.path.join("check_jsonschema", "downloads")
27-
28100
def __init__(
29101
self,
30-
file_url: str,
31-
filename: str | None = None,
32102
cache_dir: str | None = None,
33103
disable_cache: bool = False,
34104
validation_callback: t.Callable[[bytes], t.Any] | None = None,
35105
):
36-
self._file_url = file_url
37-
self._filename = filename or file_url.split("/")[-1]
38-
self._cache_dir = cache_dir or self._compute_default_cache_dir()
106+
self._cache_dir = cache_dir or _get_default_cache_dir()
39107
self._disable_cache = disable_cache
40108
self._validation_callback = validation_callback
41109

42-
def _compute_default_cache_dir(self) -> str | None:
43-
sysname = platform.system()
44-
45-
# on windows, try to get the appdata env var
46-
# this *could* result in cache_dir=None, which is fine, just skip caching in
47-
# that case
48-
if sysname == "Windows":
49-
cache_dir = os.getenv("LOCALAPPDATA", os.getenv("APPDATA"))
50-
# macOS -> app support dir
51-
elif sysname == "Darwin":
52-
cache_dir = os.path.expanduser("~/Library/Caches")
53-
# default for unknown platforms, namely linux behavior
54-
# use XDG env var and default to ~/.cache/
55-
else:
56-
cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
57-
58-
if cache_dir:
59-
cache_dir = os.path.join(cache_dir, self._CACHEDIR_NAME)
60-
61-
return cache_dir
62-
63-
def _get_request(
64-
self, *, response_ok: t.Callable[[requests.Response], bool]
65-
) -> requests.Response:
66-
try:
67-
r: requests.Response | None = None
68-
for _attempt in range(3):
69-
r = requests.get(self._file_url, stream=True)
70-
if r.ok and response_ok(r):
71-
return r
72-
assert r is not None
73-
raise FailedDownloadError(
74-
f"got response with status={r.status_code}, retries exhausted"
75-
)
76-
except requests.RequestException as e:
77-
raise FailedDownloadError("encountered error during download") from e
78-
79-
def _lastmod_from_response(self, response: requests.Response) -> float:
80-
try:
81-
return time.mktime(
82-
time.strptime(response.headers["last-modified"], self._LASTMOD_FMT)
83-
)
84-
# OverflowError: time outside of platform-specific bounds
85-
# ValueError: malformed/unparseable
86-
# LookupError: no such header
87-
except (OverflowError, ValueError, LookupError):
88-
return 0.0
89-
90-
def _cache_hit(self, cachefile: str, response: requests.Response) -> bool:
91-
# no file? miss
92-
if not os.path.exists(cachefile):
93-
return False
94-
95-
# compare mtime on any cached file against the remote last-modified time
96-
# it is considered a hit if the local file is at least as new as the remote file
97-
local_mtime = os.path.getmtime(cachefile)
98-
remote_mtime = self._lastmod_from_response(response)
99-
return local_mtime >= remote_mtime
100-
101-
def _write(self, dest: str, response: requests.Response) -> None:
102-
# download to a temp file and then move to the dest
103-
# this makes the download safe if run in parallel (parallel runs
104-
# won't create a new empty file for writing and cause failures)
105-
fp = tempfile.NamedTemporaryFile(mode="wb", delete=False)
106-
fp.write(response.content)
107-
fp.close()
108-
shutil.copy(fp.name, dest)
109-
os.remove(fp.name)
110-
111110
def _validate(self, response: requests.Response) -> bool:
112111
if not self._validation_callback:
113112
return True
@@ -118,32 +117,52 @@ def _validate(self, response: requests.Response) -> bool:
118117
except ValueError:
119118
return False
120119

121-
def _download(self) -> str:
122-
assert self._cache_dir
120+
def _download(self, file_url: str, filename: str) -> str:
121+
assert self._cache_dir is not None
123122
os.makedirs(self._cache_dir, exist_ok=True)
124-
dest = os.path.join(self._cache_dir, self._filename)
123+
dest = os.path.join(self._cache_dir, filename)
125124

126125
def check_response_for_download(r: requests.Response) -> bool:
127126
# if the response indicates a cache hit, treat it as valid
128127
# this ensures that we short-circuit any further evaluation immediately on
129128
# a hit
130-
if self._cache_hit(dest, r):
129+
if _cache_hit(dest, r):
131130
return True
132131
# we now know it's not a hit, so validate the content (forces download)
133132
return self._validate(r)
134133

135-
response = self._get_request(response_ok=check_response_for_download)
134+
response = _get_request(file_url, response_ok=check_response_for_download)
136135
# check to see if we have a file which matches the connection
137136
# only download if we do not (cache miss, vs hit)
138-
if not self._cache_hit(dest, response):
139-
self._write(dest, response)
137+
if not _cache_hit(dest, response):
138+
_atomic_write(dest, response.content)
140139

141140
return dest
142141

143142
@contextlib.contextmanager
144-
def open(self) -> t.Iterator[t.IO[bytes]]:
143+
def open(self, file_url: str, filename: str) -> t.Iterator[t.IO[bytes]]:
145144
if (not self._cache_dir) or self._disable_cache:
146-
yield io.BytesIO(self._get_request(response_ok=self._validate).content)
145+
yield io.BytesIO(_get_request(file_url, response_ok=self._validate).content)
147146
else:
148-
with open(self._download(), "rb") as fp:
147+
with open(self._download(file_url, filename), "rb") as fp:
149148
yield fp
149+
150+
def bind(self, file_url: str, filename: str | None = None) -> BoundCacheDownloader:
151+
return BoundCacheDownloader(file_url, filename, self)
152+
153+
154+
class BoundCacheDownloader:
155+
def __init__(
156+
self,
157+
file_url: str,
158+
filename: str | None,
159+
downloader: CacheDownloader,
160+
):
161+
self._file_url = file_url
162+
self._filename = filename or file_url.split("/")[-1]
163+
self._downloader = downloader
164+
165+
@contextlib.contextmanager
166+
def open(self) -> t.Iterator[t.IO[bytes]]:
167+
with self._downloader.open(self._file_url, self._filename) as fp:
168+
yield fp

src/check_jsonschema/schema_loader/readers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,9 @@ def __init__(
7979
self.url = url
8080
self.parsers = ParserSet()
8181
self.downloader = CacheDownloader(
82-
url,
83-
cache_filename,
8482
disable_cache=disable_cache,
8583
validation_callback=self._parse,
86-
)
84+
).bind(url, cache_filename)
8785
self._parsed_schema: dict | _UnsetType = _UNSET
8886

8987
def _parse(self, schema_bytes: bytes) -> t.Any:

tests/acceptance/test_remote_ref_resolution.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import pytest
44
import responses
55

6-
from check_jsonschema import cachedownloader
7-
86
CASES = {
97
"case1": {
108
"main_schema": {
@@ -39,13 +37,12 @@
3937

4038
@pytest.fixture(autouse=True)
4139
def _mock_schema_cache_dir(monkeypatch, tmp_path):
42-
def _fake_compute_default_cache_dir(self):
40+
def _fake_default_cache_dir():
4341
return str(tmp_path)
4442

4543
monkeypatch.setattr(
46-
cachedownloader.CacheDownloader,
47-
"_compute_default_cache_dir",
48-
_fake_compute_default_cache_dir,
44+
"check_jsonschema.cachedownloader._get_default_cache_dir",
45+
_fake_default_cache_dir,
4946
)
5047

5148

tests/acceptance/test_special_filetypes.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import pytest
77
import responses
88

9-
from check_jsonschema import cachedownloader
10-
119

1210
@pytest.mark.skipif(
1311
platform.system() != "Linux", reason="test requires /proc/self/ mechanism"
@@ -87,13 +85,12 @@ def test_remote_schema_requiring_retry(run_line, check_passes, tmp_path, monkeyp
8785
fires in order to parse
8886
"""
8987

90-
def _fake_compute_default_cache_dir(self):
88+
def _fake_default_cache_dir():
9189
return str(tmp_path)
9290

9391
monkeypatch.setattr(
94-
cachedownloader.CacheDownloader,
95-
"_compute_default_cache_dir",
96-
_fake_compute_default_cache_dir,
92+
"check_jsonschema.cachedownloader._get_default_cache_dir",
93+
_fake_default_cache_dir,
9794
)
9895

9996
schema_loc = "https://example.com/schema1.json"

0 commit comments

Comments
 (0)