
Commit 47e7959

Expand caching to cache remote refs
Several refinements to the CacheDownloader are needed to support this. The primary one is support for sibling directories to the `downloads` dir within the cache dir, which lets the ref resolver pass `"refs"` as a directory name. As a related change made in the course of this work, HTTP retries are expanded in scope to also cover connection errors and timeouts. Additionally, `disable_cache` is now passed down from the CLI through to the ref resolution layer. Tests are enhanced to better explore CacheDownloader behaviors, but not to cover its usage in ref resolution.
1 parent 47a85d7 commit 47e7959

File tree: 9 files changed, +211 −139 lines
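
As orientation for the diffs below, here is a minimal sketch of the cache layout produced by the new `_resolve_cache_dir` helper (see the cachedownloader.py diff). It is a simplified copy of the Linux/XDG branch only; Windows and macOS take the other branches in `_base_cache_dir`.

    # Illustration only: simplified copy of the Linux/XDG branch of the new helpers.
    from __future__ import annotations

    import os


    def _resolve_cache_dir(dirname: str = "downloads") -> str | None:
        cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        if cache_dir:
            cache_dir = os.path.join(cache_dir, "check_jsonschema", dirname)
        return cache_dir


    print(_resolve_cache_dir())        # .../check_jsonschema/downloads -- cached schemas
    print(_resolve_cache_dir("refs"))  # .../check_jsonschema/refs -- cached remote $refs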

src/check_jsonschema/cachedownloader.py

Lines changed: 72 additions & 44 deletions
@@ -11,14 +11,10 @@
 
 import requests
 
-# this will let us do any other caching we might need in the future in the same
-# cache dir (adjacent to "downloads")
-_CACHEDIR_NAME = os.path.join("check_jsonschema", "downloads")
-
 _LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
 
 
-def _get_default_cache_dir() -> str | None:
+def _base_cache_dir() -> str | None:
     sysname = platform.system()
 
     # on windows, try to get the appdata env var
@@ -34,9 +30,13 @@ def _get_default_cache_dir() -> str | None:
     else:
         cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
 
-    if cache_dir:
-        cache_dir = os.path.join(cache_dir, _CACHEDIR_NAME)
+    return cache_dir
 
+
+def _resolve_cache_dir(dirname: str = "downloads") -> str | None:
+    cache_dir = _base_cache_dir()
+    if cache_dir:
+        cache_dir = os.path.join(cache_dir, "check_jsonschema", dirname)
     return cache_dir
 
 
@@ -55,18 +55,21 @@ def _lastmod_from_response(response: requests.Response) -> float:
 def _get_request(
     file_url: str, *, response_ok: t.Callable[[requests.Response], bool]
 ) -> requests.Response:
-    try:
-        r: requests.Response | None = None
-        for _attempt in range(3):
+    num_retries = 2
+    r: requests.Response | None = None
+    for _attempt in range(num_retries + 1):
+        try:
             r = requests.get(file_url, stream=True)
-            if r.ok and response_ok(r):
-                return r
-        assert r is not None
-        raise FailedDownloadError(
-            f"got response with status={r.status_code}, retries exhausted"
-        )
-    except requests.RequestException as e:
-        raise FailedDownloadError("encountered error during download") from e
+        except requests.RequestException as e:
+            if _attempt == num_retries:
+                raise FailedDownloadError("encountered error during download") from e
+            continue
+        if r.ok and response_ok(r):
+            return r
+    assert r is not None
+    raise FailedDownloadError(
+        f"got response with status={r.status_code}, retries exhausted"
+    )
 
 
 def _atomic_write(dest: str, content: bytes) -> None:
@@ -97,27 +100,19 @@ class FailedDownloadError(Exception):
 
 
 class CacheDownloader:
-    def __init__(
-        self,
-        cache_dir: str | None = None,
-        disable_cache: bool = False,
-        validation_callback: t.Callable[[bytes], t.Any] | None = None,
-    ):
-        self._cache_dir = cache_dir or _get_default_cache_dir()
+    def __init__(self, cache_dir: str | None = None, disable_cache: bool = False):
+        if cache_dir is None:
+            self._cache_dir = _resolve_cache_dir()
+        else:
+            self._cache_dir = _resolve_cache_dir(cache_dir)
         self._disable_cache = disable_cache
-        self._validation_callback = validation_callback
-
-    def _validate(self, response: requests.Response) -> bool:
-        if not self._validation_callback:
-            return True
-
-        try:
-            self._validation_callback(response.content)
-            return True
-        except ValueError:
-            return False
 
-    def _download(self, file_url: str, filename: str) -> str:
+    def _download(
+        self,
+        file_url: str,
+        filename: str,
+        response_ok: t.Callable[[requests.Response], bool],
+    ) -> str:
         assert self._cache_dir is not None
         os.makedirs(self._cache_dir, exist_ok=True)
         dest = os.path.join(self._cache_dir, filename)
@@ -129,7 +124,7 @@ def check_response_for_download(r: requests.Response) -> bool:
             if _cache_hit(dest, r):
                 return True
             # we now know it's not a hit, so validate the content (forces download)
-            return self._validate(r)
+            return response_ok(r)
 
         response = _get_request(file_url, response_ok=check_response_for_download)
         # check to see if we have a file which matches the connection
@@ -140,15 +135,31 @@ def check_response_for_download(r: requests.Response) -> bool:
         return dest
 
     @contextlib.contextmanager
-    def open(self, file_url: str, filename: str) -> t.Iterator[t.IO[bytes]]:
+    def open(
+        self,
+        file_url: str,
+        filename: str,
+        validate_response: t.Callable[[requests.Response], bool],
+    ) -> t.Iterator[t.IO[bytes]]:
         if (not self._cache_dir) or self._disable_cache:
-            yield io.BytesIO(_get_request(file_url, response_ok=self._validate).content)
+            yield io.BytesIO(
+                _get_request(file_url, response_ok=validate_response).content
+            )
         else:
-            with open(self._download(file_url, filename), "rb") as fp:
+            with open(
+                self._download(file_url, filename, response_ok=validate_response), "rb"
+            ) as fp:
                 yield fp
 
-    def bind(self, file_url: str, filename: str | None = None) -> BoundCacheDownloader:
-        return BoundCacheDownloader(file_url, filename, self)
+    def bind(
+        self,
+        file_url: str,
+        filename: str | None = None,
+        validation_callback: t.Callable[[bytes], t.Any] | None = None,
+    ) -> BoundCacheDownloader:
+        return BoundCacheDownloader(
+            file_url, filename, self, validation_callback=validation_callback
+        )
 
 
 class BoundCacheDownloader:
@@ -157,12 +168,29 @@ def __init__(
         file_url: str,
         filename: str | None,
         downloader: CacheDownloader,
+        *,
+        validation_callback: t.Callable[[bytes], t.Any] | None = None,
     ):
         self._file_url = file_url
         self._filename = filename or file_url.split("/")[-1]
         self._downloader = downloader
+        self._validation_callback = validation_callback
 
     @contextlib.contextmanager
     def open(self) -> t.Iterator[t.IO[bytes]]:
-        with self._downloader.open(self._file_url, self._filename) as fp:
+        with self._downloader.open(
+            self._file_url,
+            self._filename,
+            validate_response=self._validate_response,
+        ) as fp:
             yield fp
+
+    def _validate_response(self, response: requests.Response) -> bool:
+        if not self._validation_callback:
+            return True
+
+        try:
+            self._validation_callback(response.content)
+            return True
+        except ValueError:
+            return False
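
With this change, content validation moves off the CacheDownloader constructor and onto bind(), which forwards it to the BoundCacheDownloader. A minimal sketch of the resulting calling pattern, mirroring the readers.py change further down; the URL, filename, and parse function here are placeholders, not taken from the commit:

    import json

    from check_jsonschema.cachedownloader import CacheDownloader


    def parse_bytes(data: bytes) -> object:
        # a ValueError here (e.g. json.JSONDecodeError) makes _validate_response
        # report the response as bad, causing a retry and eventually FailedDownloadError
        return json.loads(data)


    bound = CacheDownloader(disable_cache=False).bind(
        "https://example.com/schema.json",  # placeholder URL
        "schema.json",                      # placeholder cache filename
        validation_callback=parse_bytes,
    )
    with bound.open() as fp:
        content = fp.read()

Note that `cache_dir` on the new constructor is a directory name resolved under the shared check_jsonschema cache dir (via `_resolve_cache_dir`), not an arbitrary path; that is what lets the resolver request a `"refs"` sibling of `"downloads"`.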

src/check_jsonschema/cli/main_command.py

Lines changed: 2 additions & 2 deletions
@@ -300,8 +300,8 @@ def build_schema_loader(args: ParseResult) -> SchemaLoaderBase:
     assert args.schema_path is not None
     return SchemaLoader(
         args.schema_path,
-        args.cache_filename,
-        args.disable_cache,
+        cache_filename=args.cache_filename,
+        disable_cache=args.disable_cache,
         base_uri=args.base_uri,
         validator_class=args.validator_class,
     )

src/check_jsonschema/schema_loader/main.py

Lines changed: 6 additions & 4 deletions
@@ -57,14 +57,16 @@ def get_validator(
 
 class SchemaLoader(SchemaLoaderBase):
     validator_class: type[jsonschema.protocols.Validator] | None = None
+    disable_cache: bool = True
 
     def __init__(
         self,
         schemafile: str,
+        *,
         cache_filename: str | None = None,
-        disable_cache: bool = False,
         base_uri: str | None = None,
         validator_class: type[jsonschema.protocols.Validator] | None = None,
+        disable_cache: bool = True,
     ) -> None:
         # record input parameters (these are not to be modified)
         self.schemafile = schemafile
@@ -140,7 +142,7 @@ def get_validator(
         # reference resolution
         # with support for YAML, TOML, and other formats from the parsers
         reference_registry = make_reference_registry(
-            self._parsers, retrieval_uri, schema
+            self._parsers, retrieval_uri, schema, self.disable_cache
         )
 
         if self.validator_class is None:
@@ -171,7 +173,7 @@ def get_validator(
 
 
 class BuiltinSchemaLoader(SchemaLoader):
-    def __init__(self, schema_name: str, base_uri: str | None = None) -> None:
+    def __init__(self, schema_name: str, *, base_uri: str | None = None) -> None:
         self.schema_name = schema_name
         self.base_uri = base_uri
         self._parsers = ParserSet()
@@ -187,7 +189,7 @@ def get_schema(self) -> dict[str, t.Any]:
 
 
 class MetaSchemaLoader(SchemaLoaderBase):
-    def __init__(self, base_uri: str | None = None) -> None:
+    def __init__(self, *, base_uri: str | None = None) -> None:
         if base_uri is not None:
             raise NotImplementedError(
                 "'--base-uri' was used with '--metaschema'. "

src/check_jsonschema/schema_loader/readers.py

Lines changed: 1 addition & 2 deletions
@@ -80,8 +80,7 @@ def __init__(
         self.parsers = ParserSet()
         self.downloader = CacheDownloader(
             disable_cache=disable_cache,
-            validation_callback=self._parse,
-        ).bind(url, cache_filename)
+        ).bind(url, cache_filename, validation_callback=self._parse)
         self._parsed_schema: dict | _UnsetType = _UNSET
 
     def _parse(self, schema_bytes: bytes) -> t.Any:

src/check_jsonschema/schema_loader/resolver.py

Lines changed: 23 additions & 7 deletions
@@ -1,18 +1,19 @@
 from __future__ import annotations
 
+import hashlib
 import typing as t
 import urllib.parse
 
 import referencing
-import requests
 from referencing.jsonschema import DRAFT202012, Schema
 
+from ..cachedownloader import CacheDownloader
 from ..parsers import ParserSet
 from ..utils import filename2path
 
 
 def make_reference_registry(
-    parsers: ParserSet, retrieval_uri: str | None, schema: dict
+    parsers: ParserSet, retrieval_uri: str | None, schema: dict, disable_cache: bool
 ) -> referencing.Registry:
     id_attribute_: t.Any = schema.get("$id")
     if isinstance(id_attribute_, str):
@@ -26,7 +27,9 @@ def make_reference_registry(
     # mypy does not recognize that Registry is an `attrs` class and has `retrieve` as an
     # argument to its implicit initializer
     registry: referencing.Registry = referencing.Registry(  # type: ignore[call-arg]
-        retrieve=create_retrieve_callable(parsers, retrieval_uri, id_attribute)
+        retrieve=create_retrieve_callable(
+            parsers, retrieval_uri, id_attribute, disable_cache
+        )
     )
 
     if retrieval_uri is not None:
@@ -38,13 +41,17 @@
 
 
 def create_retrieve_callable(
-    parser_set: ParserSet, retrieval_uri: str | None, id_attribute: str | None
+    parser_set: ParserSet,
+    retrieval_uri: str | None,
+    id_attribute: str | None,
+    disable_cache: bool,
 ) -> t.Callable[[str], referencing.Resource[Schema]]:
     base_uri = id_attribute
     if base_uri is None:
         base_uri = retrieval_uri
 
     cache = ResourceCache()
+    downloader = CacheDownloader("refs", disable_cache)
 
     def get_local_file(uri: str) -> t.Any:
         path = filename2path(uri)
@@ -62,10 +69,19 @@ def retrieve_reference(uri: str) -> referencing.Resource[Schema]:
 
         full_uri_scheme = urllib.parse.urlsplit(full_uri).scheme
         if full_uri_scheme in ("http", "https"):
-            data = requests.get(full_uri, stream=True)
-            parsed_object = parser_set.parse_data_with_path(
-                data.content, full_uri, "json"
+
+            def validation_callback(content: bytes) -> None:
+                parser_set.parse_data_with_path(content, full_uri, "json")
+
+            bound_downloader = downloader.bind(
+                full_uri,
+                hashlib.md5(full_uri.encode()).hexdigest(),
+                validation_callback,
             )
+            with bound_downloader.open() as fp:
+                data = fp.read()
+
+            parsed_object = parser_set.parse_data_with_path(data, full_uri, "json")
         else:
             parsed_object = get_local_file(full_uri)
 
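
One design note on the resolver change above, as a guess at the rationale: remote `$ref` URIs are arbitrary strings, so each one is cached under the MD5 hex digest of the full URI rather than trying to turn the URI itself into a filename. For example (placeholder URI):

    import hashlib

    ref_uri = "https://example.com/schemas/address.json"  # placeholder
    cache_filename = hashlib.md5(ref_uri.encode()).hexdigest()
    print(cache_filename)  # a fixed-length hex name stored under .../check_jsonschema/refs/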

tests/acceptance/test_remote_ref_resolution.py

Lines changed: 0 additions & 11 deletions
@@ -35,17 +35,6 @@
 }
 
 
-@pytest.fixture(autouse=True)
-def _mock_schema_cache_dir(monkeypatch, tmp_path):
-    def _fake_default_cache_dir():
-        return str(tmp_path)
-
-    monkeypatch.setattr(
-        "check_jsonschema.cachedownloader._get_default_cache_dir",
-        _fake_default_cache_dir,
-    )
-
-
 @pytest.mark.parametrize("check_passes", (True, False))
 @pytest.mark.parametrize("casename", ("case1", "case2"))
 def test_remote_ref_resolution_simple_case(run_line, check_passes, casename, tmp_path):

tests/acceptance/test_special_filetypes.py

Lines changed: 1 addition & 10 deletions
@@ -79,20 +79,11 @@ def test_schema_and_instance_in_fifos(tmp_path, run_line, check_succeeds):
 
 
 @pytest.mark.parametrize("check_passes", (True, False))
-def test_remote_schema_requiring_retry(run_line, check_passes, tmp_path, monkeypatch):
+def test_remote_schema_requiring_retry(run_line, check_passes, tmp_path):
     """
     a "remote schema" (meaning HTTPS) with bad data, therefore requiring that a retry
     fires in order to parse
     """
-
-    def _fake_default_cache_dir():
-        return str(tmp_path)
-
-    monkeypatch.setattr(
-        "check_jsonschema.cachedownloader._get_default_cache_dir",
-        _fake_default_cache_dir,
-    )
-
     schema_loc = "https://example.com/schema1.json"
     responses.add("GET", schema_loc, body="", match_querystring=None)
     responses.add(

tests/conftest.py

Lines changed: 14 additions & 0 deletions
@@ -46,3 +46,17 @@ def in_tmp_dir(request, tmp_path):
     os.chdir(str(tmp_path))
     yield
     os.chdir(request.config.invocation_dir)
+
+
+@pytest.fixture
+def cache_dir(tmp_path):
+    return tmp_path / ".cache"
+
+
+@pytest.fixture(autouse=True)
+def patch_cache_dir(monkeypatch, cache_dir):
+    with monkeypatch.context() as m:
+        m.setattr(
+            "check_jsonschema.cachedownloader._base_cache_dir", lambda: str(cache_dir)
+        )
+        yield m
