@@ -60,23 +60,18 @@ def _compute_default_cache_dir(self) -> str | None:
60
60
61
61
return cache_dir
62
62
63
def _get_request(
    self, *, response_ok: t.Callable[[requests.Response], bool]
) -> requests.Response:
    """Fetch ``self._file_url``, retrying until a response is acceptable.

    Up to three GET requests are made. A response is accepted when its HTTP
    status is OK and ``response_ok(response)`` returns true.

    :param response_ok: extra acceptance check, run on each response whose
        status is OK
    :returns: the first accepted response
    :raises FailedDownloadError: if every attempt is rejected, or if
        ``requests`` raises a ``RequestException`` during a request
    """
    try:
        # retries are done manually in this loop, rather than via urllib3
        # retries
        r: requests.Response | None = None
        for _attempt in range(3):
            r = requests.get(self._file_url, stream=True)
            if r.ok and response_ok(r):
                return r
            # a rejected response is never handed to the caller; it was
            # requested with stream=True, so close it explicitly to release
            # its connection before the next attempt
            r.close()
        # the loop body always runs, so 'r' is bound; the assert narrows the
        # optional type for the attribute access below
        assert r is not None
        raise FailedDownloadError(
            f"got response with status={r.status_code}, retries exhausted"
        )
    except requests.RequestException as e:
        raise FailedDownloadError("encountered error during download") from e
@@ -113,12 +108,31 @@ def _write(self, dest: str, response: requests.Response) -> None:
113
108
shutil .copy (fp .name , dest )
114
109
os .remove (fp .name )
115
110
111
def _validate(self, response: requests.Response) -> bool:
    """Report whether ``response`` passes the configured validation callback.

    :param response: the response whose ``content`` is handed to the callback
    :returns: ``True`` when no callback is configured or the callback returns
        without raising; ``False`` when the callback raises ``ValueError``
    """
    # compare against None explicitly: a callback object that happens to be
    # falsy (e.g. one defining __len__ or __bool__) must still be invoked
    if self._validation_callback is None:
        return True

    try:
        # reading '.content' loads the response body
        self._validation_callback(response.content)
    except ValueError:
        return False
    return True
120
+
116
121
def _download (self ) -> str :
117
122
assert self ._cache_dir
118
123
os .makedirs (self ._cache_dir , exist_ok = True )
119
124
dest = os .path .join (self ._cache_dir , self ._filename )
120
125
121
- response = self ._get_request ()
126
+ def check_response_for_download (r : requests .Response ) -> bool :
127
+ # if the response indicates a cache hit, treat it as valid
128
+ # this ensures that we short-circuit any further evaluation immediately on
129
+ # a hit
130
+ if self ._cache_hit (dest , r ):
131
+ return True
132
+ # we now know it's not a hit, so validate the content (forces download)
133
+ return self ._validate (r )
134
+
135
+ response = self ._get_request (response_ok = check_response_for_download )
122
136
# check to see if we have a file which matches the connection
123
137
# only download if we do not (cache miss, vs hit)
124
138
if not self ._cache_hit (dest , response ):
@@ -129,7 +143,7 @@ def _download(self) -> str:
129
143
@contextlib.contextmanager
def open(self) -> t.Iterator[t.IO[bytes]]:
    """Yield a binary file object holding the downloaded content.

    With a cache directory set and caching enabled, the file is obtained via
    ``self._download()`` and opened from disk for reading. Otherwise the
    content is fetched with ``self._get_request`` (checked by
    ``self._validate``) and served from an in-memory buffer.
    """
    caching_enabled = bool(self._cache_dir) and not self._disable_cache
    if caching_enabled:
        # inside the class body this name resolves to the builtin open()
        with open(self._download(), "rb") as cached_file:
            yield cached_file
    else:
        response = self._get_request(response_ok=self._validate)
        yield io.BytesIO(response.content)
0 commit comments