Skip to content

Commit 27d3622

Browse files
committed
refactor: Add _http_get_download() to Downloader and rework methods
1 parent 382e256 commit 27d3622

File tree

2 files changed

+41
-86
lines changed

2 files changed

+41
-86
lines changed

src/pip/_internal/network/download.py

Lines changed: 31 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
from collections.abc import Iterable
1010
from http import HTTPStatus
11-
from typing import BinaryIO
11+
from typing import BinaryIO, Mapping
1212

1313
from pip._vendor.requests.models import Response
1414
from pip._vendor.urllib3.exceptions import ReadTimeoutError
@@ -134,33 +134,9 @@ def _get_http_response_filename(resp: Response, link: Link) -> str:
134134
return filename
135135

136136

137-
def _http_get_download(
138-
session: PipSession,
139-
link: Link,
140-
range_start: int | None = 0,
141-
if_range: str | None = None,
142-
) -> Response:
143-
target_url = link.url.split("#", 1)[0]
144-
headers = HEADERS.copy()
145-
# request a partial download
146-
if range_start:
147-
headers["Range"] = f"bytes={range_start}-"
148-
# make sure the file hasn't changed
149-
if if_range:
150-
headers["If-Range"] = if_range
151-
try:
152-
resp = session.get(target_url, headers=headers, stream=True)
153-
raise_for_status(resp)
154-
except NetworkConnectionError as e:
155-
assert e.response is not None
156-
logger.critical("HTTP error %s while getting %s", e.response.status_code, link)
157-
raise
158-
return resp
159-
160-
161137
@dataclass
162138
class _FileDownload:
163-
"""Stores the state of a single file download."""
139+
"""Stores the state of a single link download."""
164140

165141
link: Link
166142
output_file: BinaryIO
@@ -175,7 +151,7 @@ def write_chunk(self, data: bytes) -> None:
175151
self.bytes_received += len(data)
176152
self.output_file.write(data)
177153

178-
def reset_download(self) -> None:
154+
def reset_file(self) -> None:
179155
"""Delete any saved data and reset progress to zero."""
180156
self.output_file.seek(0)
181157
self.output_file.truncate()
@@ -206,7 +182,7 @@ def batch(
206182

207183
def __call__(self, link: Link, location: str) -> tuple[str, str]:
208184
"""Download the file given by link into location."""
209-
resp = _http_get_download(self._session, link)
185+
resp = self._http_get(link)
210186
download_size = _get_http_response_size(resp)
211187

212188
filepath = os.path.join(location, _get_http_response_filename(resp, link))
@@ -228,12 +204,6 @@ def _process_response(self, download: _FileDownload, resp: Response) -> None:
228204
download.size,
229205
range_start=download.bytes_received,
230206
)
231-
self._write_chunks_to_file(download, chunks)
232-
233-
def _write_chunks_to_file(
234-
self, download: _FileDownload, chunks: Iterable[bytes]
235-
) -> None:
236-
"""Write the chunks to the file and return the number of bytes received."""
237207
try:
238208
for chunk in chunks:
239209
download.write_chunk(chunk)
@@ -246,7 +216,6 @@ def _write_chunks_to_file(
246216

247217
def _attempt_resume(self, download: _FileDownload, resp: Response) -> None:
248218
"""Attempt to resume the download if connection was dropped."""
249-
etag_or_last_modified = _get_http_response_etag_or_last_modified(resp)
250219

251220
while download.reattempts < self._resume_retries and download.is_incomplete():
252221
assert download.size is not None
@@ -259,22 +228,14 @@ def _attempt_resume(self, download: _FileDownload, resp: Response) -> None:
259228
)
260229

261230
try:
262-
# Try to resume the download using a HTTP range request.
263-
resume_resp = _http_get_download(
264-
self._session,
265-
download.link,
266-
range_start=download.bytes_received,
267-
if_range=etag_or_last_modified,
268-
)
269-
231+
resume_resp = self._http_get_resume(download, should_match=resp)
270232
# Fallback: if the server responded with 200 (i.e., the file has
271233
# since been modified or range requests are unsupported) or any
272234
# other unexpected status, restart the download from the beginning.
273235
must_restart = resume_resp.status_code != HTTPStatus.PARTIAL_CONTENT
274236
if must_restart:
275-
download.size, etag_or_last_modified = self._reset_download_state(
276-
download, resume_resp
277-
)
237+
download.reset_file()
238+
download.size = _get_http_response_size(resume_resp)
278239

279240
self._process_response(download, resume_resp)
280241
except (ConnectionError, ReadTimeoutError, OSError):
@@ -285,12 +246,27 @@ def _attempt_resume(self, download: _FileDownload, resp: Response) -> None:
285246
os.remove(download.output_file.name)
286247
raise IncompleteDownloadError(download)
287248

288-
def _reset_download_state(
289-
self, download: _FileDownload, resp: Response
290-
) -> tuple[int | None, str | None]:
291-
"""Reset the download state to restart downloading from the beginning."""
292-
download.reset_download()
293-
total_length = _get_http_response_size(resp)
294-
etag_or_last_modified = _get_http_response_etag_or_last_modified(resp)
295-
296-
return total_length, etag_or_last_modified
249+
def _http_get_resume(
250+
self, download: _FileDownload, should_match: Response
251+
) -> Response:
252+
"""Issue an HTTP range request to resume the download."""
253+
headers = HEADERS.copy()
254+
headers["Range"] = f"bytes={download.bytes_received}-"
255+
# If possible, use a conditional range request to avoid corrupted
256+
# downloads caused by the remote file changing in-between.
257+
if identifier := _get_http_response_etag_or_last_modified(should_match):
258+
headers["If-Range"] = identifier
259+
return self._http_get(download.link, headers)
260+
261+
def _http_get(self, link: Link, headers: Mapping[str, str] = HEADERS) -> Response:
262+
target_url = link.url_without_fragment
263+
try:
264+
resp = self._session.get(target_url, headers=headers, stream=True)
265+
raise_for_status(resp)
266+
except NetworkConnectionError as e:
267+
assert e.response is not None
268+
logger.critical(
269+
"HTTP error %s while getting %s", e.response.status_code, link
270+
)
271+
raise
272+
return resp

tests/unit/test_network_download.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pip._internal.network.download import (
1313
Downloader,
1414
_get_http_response_size,
15-
_http_get_download,
1615
_prepare_download,
1716
parse_content_disposition,
1817
sanitize_content_filename,
@@ -149,29 +148,6 @@ def test_sanitize_content_filename__platform_dependent(
149148
assert sanitize_content_filename(filename) == expected
150149

151150

152-
@pytest.mark.parametrize(
153-
"range_start, if_range, expected_headers",
154-
[
155-
(None, None, HEADERS),
156-
(1234, None, {**HEADERS, "Range": "bytes=1234-"}),
157-
(1234, '"etag"', {**HEADERS, "Range": "bytes=1234-", "If-Range": '"etag"'}),
158-
],
159-
)
160-
def test_http_get_download(
161-
range_start: int | None,
162-
if_range: str | None,
163-
expected_headers: dict[str, str],
164-
) -> None:
165-
session = PipSession()
166-
session.get = MagicMock()
167-
link = Link("http://example.com/foo.tgz")
168-
with patch("pip._internal.network.download.raise_for_status"):
169-
_http_get_download(session, link, range_start, if_range)
170-
session.get.assert_called_once_with(
171-
"http://example.com/foo.tgz", headers=expected_headers, stream=True
172-
)
173-
174-
175151
@pytest.mark.parametrize(
176152
"content_disposition, default_filename, expected",
177153
[
@@ -323,7 +299,7 @@ def test_downloader(
323299
resume_retries: int,
324300
mock_responses: list[tuple[dict[str, str], int, bytes]],
325301
# list of (range_start, if_range)
326-
expected_resume_args: list[tuple[int | None, int | None]],
302+
expected_resume_args: list[tuple[int | None, str | None]],
327303
# expected_bytes is None means the download should fail
328304
expected_bytes: bytes | None,
329305
tmpdir: Path,
@@ -338,9 +314,9 @@ def test_downloader(
338314
resp.headers = headers
339315
resp.status_code = status_code
340316
responses.append(resp)
341-
_http_get_download = MagicMock(side_effect=responses)
317+
_http_get_mock = MagicMock(side_effect=responses)
342318

343-
with patch("pip._internal.network.download._http_get_download", _http_get_download):
319+
with patch.object(Downloader, "_http_get", _http_get_mock):
344320
if expected_bytes is None:
345321
remove = MagicMock(return_value=None)
346322
with patch("os.remove", remove):
@@ -354,9 +330,12 @@ def test_downloader(
354330
downloaded_bytes = downloaded_file.read()
355331
assert downloaded_bytes == expected_bytes
356332

357-
calls = [call(session, link)] # the initial request
333+
calls = [call(link)] # the initial GET request
358334
for range_start, if_range in expected_resume_args:
359-
calls.append(call(session, link, range_start=range_start, if_range=if_range))
335+
headers = {**HEADERS, "Range": f"bytes={range_start}-"}
336+
if if_range:
337+
headers["If-Range"] = if_range
338+
calls.append(call(link, headers))
360339

361-
# Make sure that the download makes additional requests for resumption
362-
_http_get_download.assert_has_calls(calls)
340+
# Make sure that the downloader makes additional requests for resumption
341+
_http_get_mock.assert_has_calls(calls)

0 commit comments

Comments
 (0)