diff --git a/httpie/downloads.py b/httpie/downloads.py index 9c4b895e6f..205379e102 100644 --- a/httpie/downloads.py +++ b/httpie/downloads.py @@ -2,6 +2,7 @@ Download mode implementation. """ + import mimetypes import os import re @@ -12,10 +13,9 @@ import requests +from .context import Environment from .models import HTTPResponse, OutputOptions from .output.streams import RawStream -from .context import Environment - PARTIAL_CONTENT = 206 @@ -37,24 +37,23 @@ def parse_content_range(content_range: str, resumed_from: int) -> int: """ if content_range is None: - raise ContentRangeError('Missing Content-Range') + raise ContentRangeError("Missing Content-Range") pattern = ( - r'^bytes (?P\d+)-(?P\d+)' - r'/(\*|(?P\d+))$' + r"^bytes (?P\d+)-(?P\d+)" + r"/(\*|(?P\d+))$" ) match = re.match(pattern, content_range) if not match: - raise ContentRangeError( - f'Invalid Content-Range format {content_range!r}') + raise ContentRangeError(f"Invalid Content-Range format {content_range!r}") content_range_dict = match.groupdict() - first_byte_pos = int(content_range_dict['first_byte_pos']) - last_byte_pos = int(content_range_dict['last_byte_pos']) + first_byte_pos = int(content_range_dict["first_byte_pos"]) + last_byte_pos = int(content_range_dict["last_byte_pos"]) instance_length = ( - int(content_range_dict['instance_length']) - if content_range_dict['instance_length'] + int(content_range_dict["instance_length"]) + if content_range_dict["instance_length"] else None ) @@ -64,27 +63,24 @@ def parse_content_range(content_range: str, resumed_from: int) -> int: # last-byte-pos value, is invalid. The recipient of an invalid # byte-content-range- spec MUST ignore it and any content # transferred along with it." - if (first_byte_pos > last_byte_pos - or (instance_length is not None - and instance_length <= last_byte_pos)): - raise ContentRangeError( - f'Invalid Content-Range returned: {content_range!r}') + if first_byte_pos > last_byte_pos or ( + instance_length is not None and instance_length <= last_byte_pos + ): + raise ContentRangeError(f"Invalid Content-Range returned: {content_range!r}") - if (first_byte_pos != resumed_from - or (instance_length is not None - and last_byte_pos + 1 != instance_length)): + if first_byte_pos != resumed_from or ( + instance_length is not None and last_byte_pos + 1 != instance_length + ): # Not what we asked for. raise ContentRangeError( - f'Unexpected Content-Range returned ({content_range!r})' + f"Unexpected Content-Range returned ({content_range!r})" f' for the requested Range ("bytes={resumed_from}-")' ) return last_byte_pos + 1 -def filename_from_content_disposition( - content_disposition: str -) -> Optional[str]: +def filename_from_content_disposition(content_disposition: str) -> Optional[str]: """ Extract and validate filename from a Content-Disposition header. @@ -94,28 +90,28 @@ def filename_from_content_disposition( """ # attachment; filename=jakubroztocil-httpie-0.4.1-20-g40bd8f6.tar.gz - msg = Message(f'Content-Disposition: {content_disposition}') + msg = Message(f"Content-Disposition: {content_disposition}") filename = msg.get_filename() if filename: # Basic sanitation. - filename = os.path.basename(filename).lstrip('.').strip() + filename = os.path.basename(filename).lstrip(".").strip() if filename: return filename def filename_from_url(url: str, content_type: Optional[str]) -> str: - fn = urlsplit(url).path.rstrip('/') - fn = os.path.basename(fn) if fn else 'index' - if '.' not in fn and content_type: - content_type = content_type.split(';')[0] - if content_type == 'text/plain': + fn = urlsplit(url).path.rstrip("/") + fn = os.path.basename(fn) if fn else "index" + if "." not in fn and content_type: + content_type = content_type.split(";")[0] + if content_type == "text/plain": # mimetypes returns '.ksh' - ext = '.txt' + ext = ".txt" else: ext = mimetypes.guess_extension(content_type) - if ext == '.htm': - ext = '.html' + if ext == ".htm": + ext = ".html" if ext: fn += ext @@ -136,12 +132,12 @@ def trim_filename(filename: str, max_len: int) -> str: def get_filename_max_length(directory: str) -> int: max_len = 255 - if hasattr(os, 'pathconf') and 'PC_NAME_MAX' in os.pathconf_names: - max_len = os.pathconf(directory, 'PC_NAME_MAX') + if hasattr(os, "pathconf") and "PC_NAME_MAX" in os.pathconf_names: + max_len = os.pathconf(directory, "PC_NAME_MAX") return max_len -def trim_filename_if_needed(filename: str, directory='.', extra=0) -> str: +def trim_filename_if_needed(filename: str, directory=".", extra=0) -> str: max_len = get_filename_max_length(directory) - extra if len(filename) > max_len: filename = trim_filename(filename, max_len) @@ -151,7 +147,7 @@ def trim_filename_if_needed(filename: str, directory='.', extra=0) -> str: def get_unique_filename(filename: str, exists=os.path.exists) -> str: attempt = 0 while True: - suffix = f'-{attempt}' if attempt > 0 else '' + suffix = f"-{attempt}" if attempt > 0 else "" try_filename = trim_filename_if_needed(filename, extra=len(suffix)) try_filename += suffix if not exists(try_filename): @@ -161,12 +157,7 @@ def get_unique_filename(filename: str, exists=os.path.exists) -> str: class Downloader: - def __init__( - self, - env: Environment, - output_file: IO = None, - resume: bool = False - ): + def __init__(self, env: Environment, output_file: IO = None, resume: bool = False): """ :param resume: Should the download resume if partial download already exists. @@ -190,19 +181,17 @@ def pre_request(self, request_headers: dict): """ # Ask the server not to encode the content so that we can resume, etc. - request_headers['Accept-Encoding'] = 'identity' + request_headers["Accept-Encoding"] = "identity" if self._resume: bytes_have = os.path.getsize(self._output_file.name) if bytes_have: # Set ``Range`` header to resume the download # TODO: Use "If-Range: mtime" to make sure it's fresh? - request_headers['Range'] = f'bytes={bytes_have}-' + request_headers["Range"] = f"bytes={bytes_have}-" self._resumed_from = bytes_have def start( - self, - initial_url: str, - final_response: requests.Response + self, initial_url: str, final_response: requests.Response ) -> Tuple[RawStream, IO]: """ Initiate and return a stream for `response` body with progress @@ -216,13 +205,27 @@ def start( """ assert not self.status.time_started - # FIXME: some servers still might sent Content-Encoding: gzip - # + # Some servers may still send a compressed body even though + # we ask for identity encoding. In that case, ``Content-Length`` + # refers to the encoded size (RFC 9110 ยง 8.6), so we disable + # automatic decoding to make our byte tracking match. try: - total_size = int(final_response.headers['Content-Length']) + total_size = int(final_response.headers["Content-Length"]) except (KeyError, ValueError, TypeError): total_size = None + content_encoding = final_response.headers.get("Content-Encoding") + if content_encoding: + final_response.raw.decode_content = False + + class EncodedHTTPResponse(HTTPResponse): + def iter_body(self, chunk_size=1): # type: ignore[override] + return final_response.raw.stream(chunk_size, decode_content=False) + + response_msg = EncodedHTTPResponse(final_response) + else: + response_msg = HTTPResponse(final_response) + if not self._output_file: self._output_file = self._get_output_file_from_response( initial_url=initial_url, @@ -232,8 +235,7 @@ def start( # `--output, -o` provided if self._resume and final_response.status_code == PARTIAL_CONTENT: total_size = parse_content_range( - final_response.headers.get('Content-Range'), - self._resumed_from + final_response.headers.get("Content-Range"), self._resumed_from ) else: @@ -244,9 +246,11 @@ def start( except OSError: pass # stdout - output_options = OutputOptions.from_message(final_response, headers=False, body=True) + output_options = OutputOptions.from_message( + final_response, headers=False, body=True + ) stream = RawStream( - msg=HTTPResponse(final_response), + msg=response_msg, output_options=output_options, on_body_chunk_downloaded=self.chunk_downloaded, ) @@ -254,7 +258,7 @@ def start( self.status.started( output_file=self._output_file, resumed_from=self._resumed_from, - total_size=total_size + total_size=total_size, ) return stream, self._output_file @@ -292,16 +296,17 @@ def _get_output_file_from_response( ) -> IO: # Output file not specified. Pick a name that doesn't exist yet. filename = None - if 'Content-Disposition' in final_response.headers: + if "Content-Disposition" in final_response.headers: filename = filename_from_content_disposition( - final_response.headers['Content-Disposition']) + final_response.headers["Content-Disposition"] + ) if not filename: filename = filename_from_url( url=initial_url, - content_type=final_response.headers.get('Content-Type'), + content_type=final_response.headers.get("Content-Type"), ) unique_filename = get_unique_filename(filename) - return open(unique_filename, buffering=0, mode='a+b') + return open(unique_filename, buffering=0, mode="a+b") class DownloadStatus: @@ -325,11 +330,11 @@ def started(self, output_file, resumed_from=0, total_size=None): def start_display(self, output_file): from httpie.output.ui.rich_progress import ( DummyDisplay, + ProgressDisplay, StatusDisplay, - ProgressDisplay ) - message = f'Downloading to {output_file.name}' + message = f"Downloading to {output_file.name}" if self.env.show_displays: if self.total_size is None: # Rich does not support progress bars without a total @@ -341,9 +346,7 @@ def start_display(self, output_file): self.display = DummyDisplay(self.env) self.display.start( - total=self.total_size, - at=self.downloaded, - description=message + total=self.total_size, at=self.downloaded, description=message ) def chunk_downloaded(self, size): @@ -357,10 +360,7 @@ def has_finished(self): @property def time_spent(self): - if ( - self.time_started is not None - and self.time_finished is not None - ): + if self.time_started is not None and self.time_finished is not None: return self.time_finished - self.time_started else: return None @@ -369,9 +369,9 @@ def finished(self): assert self.time_started is not None assert self.time_finished is None self.time_finished = monotonic() - if hasattr(self, 'display'): + if hasattr(self, "display"): self.display.stop(self.time_spent) def terminate(self): - if hasattr(self, 'display'): + if hasattr(self, "display"): self.display.stop(self.time_spent) diff --git a/tests/test_download_gzip_regression.py b/tests/test_download_gzip_regression.py new file mode 100644 index 0000000000..2223312c88 --- /dev/null +++ b/tests/test_download_gzip_regression.py @@ -0,0 +1,80 @@ +# tests/test_download_gzip_regression.py +import gzip +import io +import shutil +import subprocess +import sys +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer + + +RAW = b"A" * 50000 +buf = io.BytesIO() +with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(RAW) +GZ = buf.getvalue() + + +class Handler(BaseHTTPRequestHandler): + def _send_common_headers(self, content_length: int) -> None: + self.send_response(200) + self.send_header("Content-Type", "application/octet-stream") + self.send_header("Content-Encoding", "gzip") + self.send_header("Content-Length", str(content_length)) + self.end_headers() + + def do_GET(self) -> None: + if self.path != "/file.gz": + self.send_error(404) + return + self._send_common_headers(len(GZ)) + self.wfile.write(GZ) + + def do_HEAD(self) -> None: + if self.path != "/file.gz": + self.send_error(404) + return + self._send_common_headers(len(GZ)) + + def do_POST(self) -> None: + if self.path != "/file.gz": + self.send_error(404) + return + self._send_common_headers(len(GZ)) + self.wfile.write(GZ) + + def log_message(self, *args) -> None: + pass + + +def start_server(): + srv = HTTPServer(("127.0.0.1", 0), Handler) + th = threading.Thread(target=srv.serve_forever, daemon=True) + th.start() + return srv, th + + +def test_download_gzip_content_length(tmp_path): + srv, _ = start_server() + try: + url = f"http://127.0.0.1:{srv.server_port}/file.gz" + out = tmp_path / "out.gz" + + exe = shutil.which("http") + if exe: + cmd = [exe, "GET", url, "--download", "--output", str(out)] + else: + cmd = [sys.executable, "-m", "httpie", "GET", url, "--download", "--output", str(out)] + + p = subprocess.run(cmd, capture_output=True, text=True) + + assert p.returncode == 0, f"stderr:\n{p.stderr}\nstdout:\n{p.stdout}" + assert out.exists(), "output file not created" + size = out.stat().st_size + assert size == len(GZ), f"saved size {size} != expected {len(GZ)}" + + combined = (p.stdout or "") + (p.stderr or "") + assert ">100%" not in combined + assert "Incomplete download" not in combined + finally: + srv.shutdown()