diff --git a/aiohttp/client.py b/aiohttp/client.py index 2b1ccb8ee03..c3f8a1286c4 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -676,7 +676,7 @@ async def _connect_and_send_request( max_field_size=max_field_size, ) try: - resp = await req.send(conn) + resp = await req._send(conn) try: await resp.start(conn) except BaseException: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 384087cd8b3..7fafc54f621 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -7,8 +7,9 @@ import sys import traceback import warnings +from collections.abc import Sequence from hashlib import md5, sha1, sha256 -from http.cookies import CookieError, Morsel, SimpleCookie +from http.cookies import BaseCookie, CookieError, Morsel, SimpleCookie from types import MappingProxyType, TracebackType from typing import ( TYPE_CHECKING, @@ -28,8 +29,9 @@ from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL -from . import hdrs, helpers, http, multipart, payload +from . 
import hdrs, multipart, payload from .abc import AbstractStreamWriter +from .base_protocol import BaseProtocol from .client_exceptions import ( ClientConnectionError, ClientOSError, @@ -40,7 +42,6 @@ ) from .compression_utils import HAS_BROTLI from .formdata import FormData -from .hdrs import CONTENT_TYPE from .helpers import ( _SENTINEL, BaseTimerContext, @@ -58,6 +59,7 @@ ) from .http import ( SERVER_SOFTWARE, + HttpProcessingError, HttpVersion, HttpVersion10, HttpVersion11, @@ -172,6 +174,7 @@ def check(self, transport: asyncio.Transport) -> None: SSL_ALLOWED_TYPES = (bool,) # type: ignore[unreachable] +_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed") _SSL_SCHEMES = frozenset(("https", "wss")) @@ -190,623 +193,352 @@ class ConnectionKey(NamedTuple): proxy_headers_hash: Optional[int] # hash(CIMultiDict) -class ClientRequest: - GET_METHODS = { - hdrs.METH_GET, - hdrs.METH_HEAD, - hdrs.METH_OPTIONS, - hdrs.METH_TRACE, - } - POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} - ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE}) - - DEFAULT_HEADERS = { - hdrs.ACCEPT: "*/*", - hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), - } - - # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. - body: Any = b"" - auth = None - response = None +class ClientResponse(HeadersMixin): + # Some of these attributes are None when created, + # but will be set by the start() method. + # As the end user will likely never see the None values, we cheat the types below. + # from the Status-Line of the response + version: Optional[HttpVersion] = None # HTTP-Version + status: int = None # type: ignore[assignment] # Status-Code + reason: Optional[str] = None # Reason-Phrase - # These class defaults help create_autospec() work correctly. - # If autospec is improved in future, maybe these can be removed. 
- url = URL() - method = "GET" + content: StreamReader = None # type: ignore[assignment] # Payload stream + _body: Optional[bytes] = None + _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] + _history: Tuple["ClientResponse", ...] = () + _raw_headers: RawHeaders = None # type: ignore[assignment] - __writer: Optional["asyncio.Task[None]"] = None # async task for streaming data - _continue = None # waiter future for '100 Continue' response + _connection: Optional["Connection"] = None # current connection + _cookies: Optional[SimpleCookie] = None + _continue: Optional["asyncio.Future[bool]"] = None + _source_traceback: Optional[traceback.StackSummary] = None + _session: Optional["ClientSession"] = None + # set up by ClientRequest after ClientResponse object creation + # post-init stage allows to not change ctor signature + _closed = True # to allow __del__ for non-initialized properly response + _released = False + _in_context = False - _skip_auto_headers: Optional["CIMultiDict[None]"] = None + _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8" - # N.B. - # Adding __del__ method with self._writer closing doesn't make sense - # because _writer is instance method, thus it keeps a reference to self. - # Until writer has finished finalizer will not be called. 
+ __writer: Optional["asyncio.Task[None]"] = None def __init__( self, method: str, url: URL, *, - params: Query = None, - headers: Optional[LooseHeaders] = None, - skip_auto_headers: Optional[Iterable[str]] = None, - data: Any = None, - cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, - version: http.HttpVersion = http.HttpVersion11, - compress: Union[str, bool] = False, - chunked: Optional[bool] = None, - expect100: bool = False, + writer: "Optional[asyncio.Task[None]]", + continue100: Optional["asyncio.Future[bool]"], + timer: Optional[BaseTimerContext], + request_info: RequestInfo, + traces: Sequence["Trace"], loop: asyncio.AbstractEventLoop, - response_class: Optional[Type["ClientResponse"]] = None, - proxy: Optional[URL] = None, - proxy_auth: Optional[BasicAuth] = None, - timer: Optional[BaseTimerContext] = None, - session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, bool, Fingerprint] = True, - proxy_headers: Optional[LooseHeaders] = None, - traces: Optional[List["Trace"]] = None, - trust_env: bool = False, - server_hostname: Optional[str] = None, - ): - if match := _CONTAINS_CONTROL_CHAR_RE.search(method): - raise ValueError( - f"Method cannot contain non-token characters {method!r} " - f"(found at least {match.group()!r})" - ) + session: Optional["ClientSession"], + ) -> None: # URL forbids subclasses, so a simple type check is enough. 
- assert type(url) is URL, url - if proxy is not None: - assert type(proxy) is URL, proxy - # FIXME: session is None in tests only, need to fix tests - # assert session is not None - if TYPE_CHECKING: - assert session is not None - self._session = session - if params: - url = url.extend_query(params) - self.original_url = url - self.url = url.with_fragment(None) if url.raw_fragment else url - self.method = method.upper() - self.chunked = chunked - self.loop = loop - self.length = None - if response_class is None: - real_response_class = ClientResponse - else: - real_response_class = response_class - self.response_class: Type[ClientResponse] = real_response_class - self._timer = timer if timer is not None else TimerNoop() - self._ssl = ssl - self.server_hostname = server_hostname + assert type(url) is URL + + self.method = method + self._real_url = url + self._url = url.with_fragment(None) if url.raw_fragment else url + if writer is not None: + self._writer = writer + if continue100 is not None: + self._continue = continue100 + self._request_info = request_info + self._timer = timer if timer is not None else TimerNoop() + self._cache: Dict[str, Any] = {} + self._traces = traces + self._loop = loop + # Save reference to _resolve_charset, so that get_encoding() will still + # work after the response has finished reading the body. 
+ if session is not None: + # store a reference to session #1985 + self._session = session + self._resolve_charset = session._resolve_charset if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) - self.update_version(version) - self.update_host(url) - self.update_headers(headers) - self.update_auto_headers(skip_auto_headers) - self.update_cookies(cookies) - self.update_content_encoding(data, compress) - self.update_auth(auth, trust_env) - self.update_proxy(proxy, proxy_auth, proxy_headers) - - self.update_body_from_data(data) - if data is not None or self.method not in self.GET_METHODS: - self.update_transfer_encoding() - self.update_expect_continue(expect100) - self._traces = [] if traces is None else traces - def __reset_writer(self, _: object = None) -> None: self.__writer = None - def _get_content_length(self) -> Optional[int]: - """Extract and validate Content-Length header value. - - Returns parsed Content-Length value or None if not set. - Raises ValueError if header exists but cannot be parsed as an integer. - """ - if hdrs.CONTENT_LENGTH not in self.headers: - return None - - content_length_hdr = self.headers[hdrs.CONTENT_LENGTH] - try: - return int(content_length_hdr) - except ValueError: - raise ValueError( - f"Invalid Content-Length header: {content_length_hdr}" - ) from None - - @property - def skip_auto_headers(self) -> CIMultiDict[None]: - return self._skip_auto_headers or CIMultiDict() - @property def _writer(self) -> Optional["asyncio.Task[None]"]: + """The writer task for streaming data. + + _writer is only provided for backwards compatibility + for subclasses that may need to access it. 
+ """ return self.__writer @_writer.setter - def _writer(self, writer: "asyncio.Task[None]") -> None: + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + """Set the writer task for streaming data.""" if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - writer.add_done_callback(self.__reset_writer) - - def is_ssl(self) -> bool: - return self.url.scheme in _SSL_SCHEMES - - @property - def ssl(self) -> Union["SSLContext", bool, Fingerprint]: - return self._ssl - - @property - def connection_key(self) -> ConnectionKey: # type: ignore[misc] - if proxy_headers := self.proxy_headers: - h: Optional[int] = hash(tuple(proxy_headers.items())) + if writer is None: + return + if writer.done(): + # The writer is already done, so we can clear it immediately. + self.__writer = None else: - h = None - url = self.url - return tuple.__new__( - ConnectionKey, - ( - url.raw_host or "", - url.port, - url.scheme in _SSL_SCHEMES, - self._ssl, - self.proxy, - self.proxy_auth, - h, - ), - ) - - @property - def host(self) -> str: - ret = self.url.raw_host - assert ret is not None - return ret - - @property - def port(self) -> Optional[int]: - return self.url.port - - @property - def request_info(self) -> RequestInfo: - headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers) - # These are created on every request, so we use a NamedTuple - # for performance reasons. We don't use the RequestInfo.__new__ - # method because it has a different signature which is provided - # for backwards compatibility only. - return tuple.__new__( - RequestInfo, (self.url, self.method, headers, self.original_url) - ) + writer.add_done_callback(self.__reset_writer) @property - def session(self) -> "ClientSession": - """Return the ClientSession instance. 
+ def cookies(self) -> SimpleCookie: + if self._cookies is None: + self._cookies = SimpleCookie() + return self._cookies - This property provides access to the ClientSession that initiated - this request, allowing middleware to make additional requests - using the same session. - """ - return self._session + @cookies.setter + def cookies(self, cookies: SimpleCookie) -> None: + self._cookies = cookies - def update_host(self, url: URL) -> None: - """Update destination host, port and connection type (ssl).""" - # get host/port - if not url.raw_host: - raise InvalidURL(url) + @reify + def url(self) -> URL: + return self._url - # basic auth info - if url.raw_user or url.raw_password: - self.auth = helpers.BasicAuth(url.user or "", url.password or "") + @reify + def real_url(self) -> URL: + return self._real_url - def update_version(self, version: Union[http.HttpVersion, str]) -> None: - """Convert request version to two elements tuple. + @reify + def host(self) -> str: + assert self._url.host is not None + return self._url.host - parser HTTP version '1.1' => (1, 1) - """ - if isinstance(version, str): - v = [part.strip() for part in version.split(".", 1)] - try: - version = http.HttpVersion(int(v[0]), int(v[1])) - except ValueError: - raise ValueError( - f"Can not parse http version number: {version}" - ) from None - self.version = version + @reify + def headers(self) -> "CIMultiDictProxy[str]": + return self._headers - def update_headers(self, headers: Optional[LooseHeaders]) -> None: - """Update request headers.""" - self.headers: CIMultiDict[str] = CIMultiDict() + @reify + def raw_headers(self) -> RawHeaders: + return self._raw_headers - # Build the host header - host = self.url.host_port_subcomponent + @reify + def request_info(self) -> RequestInfo: + return self._request_info - # host_port_subcomponent is None when the URL is a relative URL. - # but we know we do not have a relative URL here. 
- assert host is not None - self.headers[hdrs.HOST] = host + @reify + def content_disposition(self) -> Optional[ContentDisposition]: + raw = self._headers.get(hdrs.CONTENT_DISPOSITION) + if raw is None: + return None + disposition_type, params_dct = multipart.parse_content_disposition(raw) + params = MappingProxyType(params_dct) + filename = multipart.content_disposition_filename(params) + return ContentDisposition(disposition_type, params, filename) - if not headers: + def __del__(self, _warnings: Any = warnings) -> None: + if self._closed: return - if isinstance(headers, (dict, MultiDictProxy, MultiDict)): - headers = headers.items() + if self._connection is not None: + self._connection.release() + self._cleanup_writer() - for key, value in headers: # type: ignore[misc] - # A special case for Host header - if key in hdrs.HOST_ALL: - self.headers[key] = value - else: - self.headers.add(key, value) + if self._loop.get_debug(): + _warnings.warn( + f"Unclosed response {self!r}", ResourceWarning, source=self + ) + context = {"client_response": self, "message": "Unclosed response"} + if self._source_traceback: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) - def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None: - if skip_auto_headers is not None: - self._skip_auto_headers = CIMultiDict( - (hdr, None) for hdr in sorted(skip_auto_headers) - ) - used_headers = self.headers.copy() - used_headers.extend(self._skip_auto_headers) # type: ignore[arg-type] + def __repr__(self) -> str: + out = io.StringIO() + ascii_encodable_url = str(self.url) + if self.reason: + ascii_encodable_reason = self.reason.encode( + "ascii", "backslashreplace" + ).decode("ascii") else: - # Fast path when there are no headers to skip - # which is the most common case. 
- used_headers = self.headers + ascii_encodable_reason = "None" + print( + "<ClientResponse({}) [{} {}]>".format( + ascii_encodable_url, self.status, ascii_encodable_reason + ), + file=out, + ) + print(self.headers, file=out) + return out.getvalue() - for hdr, val in self.DEFAULT_HEADERS.items(): - if hdr not in used_headers: - self.headers[hdr] = val + @property + def connection(self) -> Optional["Connection"]: + return self._connection - if hdrs.USER_AGENT not in used_headers: - self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE + @reify + def history(self) -> Tuple["ClientResponse", ...]: + """A sequence of responses, if redirects occurred.""" + return self._history - def update_cookies(self, cookies: Optional[LooseCookies]) -> None: - """Update request cookies header.""" - if not cookies: - return + @reify + def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": + links_str = ", ".join(self.headers.getall("link", [])) - c = SimpleCookie() - if hdrs.COOKIE in self.headers: - c.load(self.headers.get(hdrs.COOKIE, "")) - del self.headers[hdrs.COOKIE] + if not links_str: + return MultiDictProxy(MultiDict()) - if isinstance(cookies, Mapping): - iter_cookies = cookies.items() - else: - iter_cookies = cookies # type: ignore[assignment] - for name, value in iter_cookies: - if isinstance(value, Morsel): - # Preserve coded_value - mrsl_val = value.get(value.key, Morsel()) - mrsl_val.set(value.key, value.value, value.coded_value) - c[name] = mrsl_val - else: - c[name] = value # type: ignore[assignment] + links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict() - self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() + for val in re.split(r",(?=\s*<)", links_str): + match = re.match(r"\s*<(.*)>(.*)", val) + if match is None: # Malformed link + continue + url, params_str = match.groups() + params = params_str.split(";")[1:] - def update_content_encoding(self, data: Any, compress: Union[bool, str]) -> None: - """Set request content encoding.""" - self.compress = None - if 
not data: - return + link: MultiDict[Union[str, URL]] = MultiDict() - if self.headers.get(hdrs.CONTENT_ENCODING): - if compress: - raise ValueError( - "compress can not be set if Content-Encoding header is set" - ) - elif compress: - self.compress = compress if isinstance(compress, str) else "deflate" - self.headers[hdrs.CONTENT_ENCODING] = self.compress - self.chunked = True # enable chunked, no need to deal with length + for param in params: + match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) + if match is None: # Malformed param + continue + key, _, value, _ = match.groups() - def update_transfer_encoding(self) -> None: - """Analyze transfer-encoding header.""" - te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() + link.add(key, value) - if "chunked" in te: - if self.chunked: - raise ValueError( - "chunked can not be set " - 'if "Transfer-Encoding: chunked" header is set' - ) + key = link.get("rel", url) - elif self.chunked: - if hdrs.CONTENT_LENGTH in self.headers: - raise ValueError( - "chunked can not be set if Content-Length header is set" - ) + link.add("url", self.url.join(URL(url))) - self.headers[hdrs.TRANSFER_ENCODING] = "chunked" - else: - if hdrs.CONTENT_LENGTH not in self.headers: - self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) + links.add(str(key), MultiDictProxy(link)) - def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: - """Set basic auth.""" - if auth is None: - auth = self.auth - if auth is None and trust_env and self.url.host is not None: - netrc_obj = netrc_from_env() - with contextlib.suppress(LookupError): - auth = basicauth_from_netrc(netrc_obj, self.url.host) - if auth is None: - return + return MultiDictProxy(links) - if not isinstance(auth, helpers.BasicAuth): - raise TypeError("BasicAuth() tuple is required instead") + async def start(self, connection: "Connection") -> "ClientResponse": + """Start response processing.""" + self._closed = False + self._protocol = 
connection.protocol + self._connection = connection - self.headers[hdrs.AUTHORIZATION] = auth.encode() + with self._timer: + while True: + # read response + try: + protocol = self._protocol + message, payload = await protocol.read() # type: ignore[union-attr] + except HttpProcessingError as exc: + raise ClientResponseError( + self.request_info, + self.history, + status=exc.code, + message=exc.message, + headers=exc.headers, + ) from exc - def update_body_from_data(self, body: Any) -> None: - if body is None: - return + if message.code < 100 or message.code > 199 or message.code == 101: + break - # FormData - if isinstance(body, FormData): - body = body() + if self._continue is not None: + set_result(self._continue, True) + self._continue = None - try: - body = payload.PAYLOAD_REGISTRY.get(body, disposition=None) - except payload.LookupError: - boundary = None - if CONTENT_TYPE in self.headers: - boundary = parse_mimetype(self.headers[CONTENT_TYPE]).parameters.get( - "boundary" - ) - body = FormData(body, boundary=boundary)() + # payload eof handler + payload.on_eof(self._response_eof) - self.body = body + # response status + self.version = message.version + self.status = message.code + self.reason = message.reason - # enable chunked encoding if needed - if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers: - if (size := body.size) is not None: - self.headers[hdrs.CONTENT_LENGTH] = str(size) - else: - self.chunked = True + # headers + self._headers = message.headers # type is CIMultiDictProxy + self._raw_headers = message.raw_headers # type is Tuple[bytes, bytes] - # copy payload headers - assert body.headers - headers = self.headers - skip_headers = self._skip_auto_headers - for key, value in body.headers.items(): - if key in headers or (skip_headers is not None and key in skip_headers): - continue - headers[key] = value + # payload + self.content = payload - def update_expect_continue(self, expect: bool = False) -> None: - if expect: - 
self.headers[hdrs.EXPECT] = "100-continue" - elif ( - hdrs.EXPECT in self.headers - and self.headers[hdrs.EXPECT].lower() == "100-continue" - ): - expect = True + # cookies + if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()): + cookies = SimpleCookie() + for hdr in cookie_hdrs: + try: + cookies.load(hdr) + except CookieError as exc: + client_logger.warning("Can not load response cookies: %s", exc) + self._cookies = cookies + return self - if expect: - self._continue = self.loop.create_future() + def _response_eof(self) -> None: + if self._closed: + return - def update_proxy( - self, - proxy: Optional[URL], - proxy_auth: Optional[BasicAuth], - proxy_headers: Optional[LooseHeaders], - ) -> None: - self.proxy = proxy - if proxy is None: - self.proxy_auth = None - self.proxy_headers = None + # protocol could be None because connection could be detached + protocol = self._connection and self._connection.protocol + if protocol is not None and protocol.upgraded: return - if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): - raise ValueError("proxy_auth must be None or BasicAuth() tuple") - self.proxy_auth = proxy_auth + self._closed = True + self._cleanup_writer() + self._release_connection() - if proxy_headers is not None and not isinstance( - proxy_headers, (MultiDict, MultiDictProxy) - ): - proxy_headers = CIMultiDict(proxy_headers) - self.proxy_headers = proxy_headers + @property + def closed(self) -> bool: + return self._closed - async def write_bytes( - self, - writer: AbstractStreamWriter, - conn: "Connection", - content_length: Optional[int], - ) -> None: - """ - Write the request body to the connection stream. + def close(self) -> None: + if not self._released: + self._notify_content() - This method handles writing different types of request bodies: - 1. Payload objects (using their specialized write_with_length method) - 2. Bytes/bytearray objects - 3. 
Iterable body content + self._closed = True + if self._loop.is_closed(): + return - Args: - writer: The stream writer to write the body to - conn: The connection being used for this request - content_length: Optional maximum number of bytes to write from the body - (None means write the entire body) + self._cleanup_writer() + if self._connection is not None: + self._connection.close() + self._connection = None - The method properly handles: - - Waiting for 100-Continue responses if required - - Content length constraints for chunked encoding - - Error handling for network issues, cancellation, and other exceptions - - Signaling EOF and timeout management + def release(self) -> None: + if not self._released: + self._notify_content() - Raises: - ClientOSError: When there's an OS-level error writing the body - ClientConnectionError: When there's a general connection error - asyncio.CancelledError: When the operation is cancelled + self._closed = True - """ - # 100 response - if self._continue is not None: - # Force headers to be sent before waiting for 100-continue - writer.send_headers() - await writer.drain() - await self._continue + self._cleanup_writer() + self._release_connection() - protocol = conn.protocol - assert protocol is not None - try: - if isinstance(self.body, payload.Payload): - # Specialized handling for Payload objects that know how to write themselves - await self.body.write_with_length(writer, content_length) - else: - # Handle bytes/bytearray by converting to an iterable for consistent handling - if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) - - if content_length is None: - # Write the entire body without length constraint - for chunk in self.body: - await writer.write(chunk) - else: - # Write with length constraint, respecting content_length limit - # If the body is larger than content_length, we truncate it - remaining_bytes = content_length - for chunk in self.body: - await writer.write(chunk[:remaining_bytes]) - 
remaining_bytes -= len(chunk) - if remaining_bytes <= 0: - break - except OSError as underlying_exc: - reraised_exc = underlying_exc - - # Distinguish between timeout and other OS errors for better error reporting - exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( - underlying_exc, asyncio.TimeoutError - ) - if exc_is_not_timeout: - reraised_exc = ClientOSError( - underlying_exc.errno, - f"Can not write request body for {self.url !s}", - ) - - set_exception(protocol, reraised_exc, underlying_exc) - except asyncio.CancelledError: - # Body hasn't been fully sent, so connection can't be reused - conn.close() - raise - except Exception as underlying_exc: - set_exception( - protocol, - ClientConnectionError( - "Failed to send bytes into the underlying connection " - f"{conn !s}: {underlying_exc!r}", - ), - underlying_exc, - ) - else: - # Successfully wrote the body, signal EOF and start response timeout - await writer.write_eof() - protocol.start_timeout() - - async def send(self, conn: "Connection") -> "ClientResponse": - # Specify request target: - # - CONNECT request must send authority form URI - # - not CONNECT proxy must send absolute form URI - # - most common is origin form URI - if self.method == hdrs.METH_CONNECT: - connect_host = self.url.host_subcomponent - assert connect_host is not None - path = f"{connect_host}:{self.url.port}" - elif self.proxy and not self.is_ssl(): - path = str(self.url) - else: - path = self.url.raw_path_qs + @property + def ok(self) -> bool: + """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not. 
- protocol = conn.protocol - assert protocol is not None - writer = StreamWriter( - protocol, - self.loop, - on_chunk_sent=( - functools.partial(self._on_chunk_request_sent, self.method, self.url) - if self._traces - else None - ), - on_headers_sent=( - functools.partial(self._on_headers_request_sent, self.method, self.url) - if self._traces - else None - ), - ) + This is **not** a check for ``200 OK`` but a check that the response + status is under 400. + """ + return 400 > self.status - if self.compress: - writer.enable_compression(self.compress) + def raise_for_status(self) -> None: + if not self.ok: + # reason should always be not None for a started response + assert self.reason is not None - if self.chunked is not None: - writer.enable_chunking() + # If we're in a context we can rely on __aexit__() to release as the + # exception propagates. + if not self._in_context: + self.release() - # set default content-type - if ( - self.method in self.POST_METHODS - and ( - self._skip_auto_headers is None - or hdrs.CONTENT_TYPE not in self._skip_auto_headers + raise ClientResponseError( + self.request_info, + self.history, + status=self.status, + message=self.reason, + headers=self.headers, ) - and hdrs.CONTENT_TYPE not in self.headers - ): - self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" - - v = self.version - if hdrs.CONNECTION not in self.headers: - if conn._connector.force_close: - if v == HttpVersion11: - self.headers[hdrs.CONNECTION] = "close" - elif v == HttpVersion10: - self.headers[hdrs.CONNECTION] = "keep-alive" - - # status + headers - status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" - - # Buffer headers for potential coalescing with body - await writer.write_headers(status_line, self.headers) - task: Optional["asyncio.Task[None]"] - if self.body or self._continue is not None or protocol.writing_paused: - coro = self.write_bytes(writer, conn, self._get_content_length()) - if sys.version_info >= (3, 12): - # Optimization for Python 
3.12, try to write - # bytes immediately to avoid having to schedule - # the task on the event loop. - task = asyncio.Task(coro, loop=self.loop, eager_start=True) - else: - task = self.loop.create_task(coro) - if task.done(): - task = None + def _release_connection(self) -> None: + if self._connection is not None: + if self.__writer is None: + self._connection.release() + self._connection = None else: - self._writer = task - else: - # We have nothing to write because - # - there is no body - # - the protocol does not have writing paused - # - we are not waiting for a 100-continue response - protocol.start_timeout() - writer.set_eof() - task = None - response_class = self.response_class - assert response_class is not None - self.response = response_class( - self.method, - self.original_url, - writer=task, - continue100=self._continue, - timer=self._timer, - request_info=self.request_info, - traces=self._traces, - loop=self.loop, - session=self._session, - ) - return self.response + self.__writer.add_done_callback(lambda f: self._release_connection()) - async def close(self) -> None: + async def _wait_released(self) -> None: if self.__writer is not None: try: await self.__writer @@ -817,505 +549,775 @@ async def close(self) -> None: and task.cancelling() ): raise + self._release_connection() - def terminate(self) -> None: + def _cleanup_writer(self) -> None: if self.__writer is not None: - if not self.loop.is_closed(): - self.__writer.cancel() - self.__writer.remove_done_callback(self.__reset_writer) - self.__writer = None + self.__writer.cancel() + self._session = None - async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: - for trace in self._traces: - await trace.send_request_chunk_sent(method, url, chunk) + def _notify_content(self) -> None: + content = self.content + # content can be None here, but the types are cheated elsewhere. 
+ if content and content.exception() is None: # type: ignore[truthy-bool] + set_exception(content, _CONNECTION_CLOSED_EXCEPTION) + self._released = True - async def _on_headers_request_sent( - self, method: str, url: URL, headers: "CIMultiDict[str]" - ) -> None: - for trace in self._traces: - await trace.send_request_headers(method, url, headers) + async def wait_for_close(self) -> None: + if self.__writer is not None: + try: + await self.__writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise + self.release() + async def read(self) -> bytes: + """Read response payload.""" + if self._body is None: + try: + self._body = await self.content.read() + for trace in self._traces: + await trace.send_response_chunk_received( + self.method, self.url, self._body + ) + except BaseException: + self.close() + raise + elif self._released: # Response explicitly released + raise ClientConnectionError("Connection closed") -_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed") + protocol = self._connection and self._connection.protocol + if protocol is None or not protocol.upgraded: + await self._wait_released() # Underlying connection released + return self._body + def get_encoding(self) -> str: + ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() + mimetype = parse_mimetype(ctype) -class ClientResponse(HeadersMixin): - # Some of these attributes are None when created, - # but will be set by the start() method. - # As the end user will likely never see the None values, we cheat the types below. 
- # from the Status-Line of the response - version: Optional[HttpVersion] = None # HTTP-Version - status: int = None # type: ignore[assignment] # Status-Code - reason: Optional[str] = None # Reason-Phrase + encoding = mimetype.parameters.get("charset") + if encoding: + with contextlib.suppress(LookupError, ValueError): + return codecs.lookup(encoding).name - content: StreamReader = None # type: ignore[assignment] # Payload stream - _body: Optional[bytes] = None - _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] - _history: Tuple["ClientResponse", ...] = () - _raw_headers: RawHeaders = None # type: ignore[assignment] + if mimetype.type == "application" and ( + mimetype.subtype == "json" or mimetype.subtype == "rdap" + ): + # RFC 7159 states that the default encoding is UTF-8. + # RFC 7483 defines application/rdap+json + return "utf-8" - _connection: Optional["Connection"] = None # current connection - _cookies: Optional[SimpleCookie] = None - _continue: Optional["asyncio.Future[bool]"] = None - _source_traceback: Optional[traceback.StackSummary] = None - _session: Optional["ClientSession"] = None - # set up by ClientRequest after ClientResponse object creation - # post-init stage allows to not change ctor signature - _closed = True # to allow __del__ for non-initialized properly response - _released = False - _in_context = False + if self._body is None: + raise RuntimeError( + "Cannot compute fallback encoding of a not yet read body" + ) - _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8" + return self._resolve_charset(self, self._body) - __writer: Optional["asyncio.Task[None]"] = None + async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: + """Read response payload and decode.""" + await self.read() - def __init__( + if encoding is None: + encoding = self.get_encoding() + + return self._body.decode(encoding, errors=errors) # type: ignore[union-attr] + + async def json( self, - method: 
str, - url: URL, *, - writer: "Optional[asyncio.Task[None]]", - continue100: Optional["asyncio.Future[bool]"], - timer: Optional[BaseTimerContext], - request_info: RequestInfo, - traces: List["Trace"], - loop: asyncio.AbstractEventLoop, - session: "ClientSession", + encoding: Optional[str] = None, + loads: JSONDecoder = DEFAULT_JSON_DECODER, + content_type: Optional[str] = "application/json", + ) -> Any: + """Read and decodes JSON response.""" + await self.read() + + if content_type: + if not is_expected_content_type(self.content_type, content_type): + raise ContentTypeError( + self.request_info, + self.history, + status=self.status, + message=( + "Attempt to decode JSON with " + "unexpected mimetype: %s" % self.content_type + ), + headers=self.headers, + ) + + if encoding is None: + encoding = self.get_encoding() + + return loads(self._body.decode(encoding)) # type: ignore[union-attr] + + async def __aenter__(self) -> "ClientResponse": + self._in_context = True + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: - # URL forbids subclasses, so a simple type check is enough. - assert type(url) is URL + self._in_context = False + # similar to _RequestContextManager, we do not need to check + # for exceptions, response object can close connection + # if state is broken + self.release() + await self.wait_for_close() - self.method = method - self._real_url = url - self._url = url.with_fragment(None) if url.raw_fragment else url - if writer is not None: - self._writer = writer - if continue100 is not None: - self._continue = continue100 - self._request_info = request_info - self._timer = timer if timer is not None else TimerNoop() - self._cache: Dict[str, Any] = {} - self._traces = traces - self._loop = loop - # Save reference to _resolve_charset, so that get_encoding() will still - # work after the response has finished reading the body. 
- # TODO: Fix session=None in tests (see ClientRequest.__init__). - if session is not None: - # store a reference to session #1985 - self._session = session - self._resolve_charset = session._resolve_charset +class ClientRequestBase: + """An internal class for proxy requests.""" + + POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} + + auth = None + proxy: Optional[URL] = None + response_class = ClientResponse + server_hostname: Optional[str] = None # Needed in connector.py + version = HttpVersion11 + _response = None + + # These class defaults help create_autospec() work correctly. + # If autospec is improved in future, maybe these can be removed. + url = URL() + method = "GET" + + __writer: Optional["asyncio.Task[None]"] = None # async task for streaming data + + _skip_auto_headers: Optional["CIMultiDict[None]"] = None + + # N.B. + # Adding __del__ method with self._writer closing doesn't make sense + # because _writer is instance method, thus it keeps a reference to self. + # Until writer has finished finalizer will not be called. + + def __init__( + self, + method: str, + url: URL, + *, + headers: CIMultiDict[str], + auth: Optional[BasicAuth], + loop: asyncio.AbstractEventLoop, + ssl: Union[SSLContext, bool, Fingerprint], + trust_env: bool = False, + ): + if match := _CONTAINS_CONTROL_CHAR_RE.search(method): + raise ValueError( + f"Method cannot contain non-token characters {method!r} " + f"(found at least {match.group()!r})" + ) + # URL forbids subclasses, so a simple type check is enough. 
+ assert type(url) is URL, url + self.original_url = url + self.url = url.with_fragment(None) if url.raw_fragment else url + self.method = method.upper() + self.loop = loop + self._ssl = ssl + if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._update_host(url) + self._update_headers(headers) + self._update_auth(auth, trust_env) + def __reset_writer(self, _: object = None) -> None: self.__writer = None @property def _writer(self) -> Optional["asyncio.Task[None]"]: - """The writer task for streaming data. - - _writer is only provided for backwards compatibility - for subclasses that may need to access it. - """ return self.__writer @_writer.setter - def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: - """Set the writer task for streaming data.""" + def _writer(self, writer: "asyncio.Task[None]") -> None: if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - if writer is None: - return - if writer.done(): - # The writer is already done, so we can clear it immediately. - self.__writer = None - else: - writer.add_done_callback(self.__reset_writer) + writer.add_done_callback(self.__reset_writer) @property - def cookies(self) -> SimpleCookie: - if self._cookies is None: - self._cookies = SimpleCookie() - return self._cookies + def connection_key(self) -> ConnectionKey: # type: ignore[misc] + url = self.url + return tuple.__new__( + ConnectionKey, + ( + url.raw_host or "", + url.port, + url.scheme in _SSL_SCHEMES, + self._ssl, + None, + None, + None, + ), + ) - @cookies.setter - def cookies(self, cookies: SimpleCookie) -> None: - self._cookies = cookies + @property + def _request_info(self) -> RequestInfo: + headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers) + # These are created on every request, so we use a NamedTuple + # for performance reasons. 
We don't use the RequestInfo.__new__ + # method because it has a different signature which is provided + # for backwards compatibility only. + return tuple.__new__( + RequestInfo, (self.url, self.method, headers, self.original_url) + ) - @reify - def url(self) -> URL: - return self._url + def is_ssl(self) -> bool: + return self.url.scheme in _SSL_SCHEMES - @reify - def real_url(self) -> URL: - return self._real_url + @property + def ssl(self) -> Union["SSLContext", bool, Fingerprint]: + return self._ssl - @reify - def host(self) -> str: - assert self._url.host is not None - return self._url.host + def _update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: + """Set basic auth.""" + if auth is None: + auth = self.auth + if auth is None and trust_env and self.url.host is not None: + netrc_obj = netrc_from_env() + with contextlib.suppress(LookupError): + auth = basicauth_from_netrc(netrc_obj, self.url.host) + if auth is None: + return - @reify - def headers(self) -> "CIMultiDictProxy[str]": - return self._headers + if not isinstance(auth, BasicAuth): + raise TypeError("BasicAuth() tuple is required instead") - @reify - def raw_headers(self) -> RawHeaders: - return self._raw_headers + self.headers[hdrs.AUTHORIZATION] = auth.encode() - @reify - def request_info(self) -> RequestInfo: - return self._request_info + def _update_headers(self, headers: CIMultiDict[str]) -> None: + """Update request headers.""" + self.headers: CIMultiDict[str] = CIMultiDict() - @reify - def content_disposition(self) -> Optional[ContentDisposition]: - raw = self._headers.get(hdrs.CONTENT_DISPOSITION) - if raw is None: - return None - disposition_type, params_dct = multipart.parse_content_disposition(raw) - params = MappingProxyType(params_dct) - filename = multipart.content_disposition_filename(params) - return ContentDisposition(disposition_type, params, filename) + # Build the host header + host = self.url.host_port_subcomponent - def __del__(self, _warnings: Any = 
warnings) -> None: - if self._closed: - return + # host_port_subcomponent is None when the URL is a relative URL. + # but we know we do not have a relative URL here. + assert host is not None + self.headers[hdrs.HOST] = headers.pop(hdrs.HOST, host) + self.headers.extend(headers) - if self._connection is not None: - self._connection.release() - self._cleanup_writer() + def _update_host(self, url: URL) -> None: + """Update destination host, port and connection type (ssl).""" + # get host/port + if not url.raw_host: + raise InvalidURL(url) - if self._loop.get_debug(): - _warnings.warn( - f"Unclosed response {self!r}", ResourceWarning, source=self - ) - context = {"client_response": self, "message": "Unclosed response"} - if self._source_traceback: - context["source_traceback"] = self._source_traceback - self._loop.call_exception_handler(context) + # basic auth info + if url.raw_user or url.raw_password: + self.auth = BasicAuth(url.user or "", url.password or "") - def __repr__(self) -> str: - out = io.StringIO() - ascii_encodable_url = str(self.url) - if self.reason: - ascii_encodable_reason = self.reason.encode( - "ascii", "backslashreplace" - ).decode("ascii") - else: - ascii_encodable_reason = "None" - print( - "".format( - ascii_encodable_url, self.status, ascii_encodable_reason - ), - file=out, - ) - print(self.headers, file=out) - return out.getvalue() + def _get_content_length(self) -> Optional[int]: + """Extract and validate Content-Length header value. - @property - def connection(self) -> Optional["Connection"]: - return self._connection + Returns parsed Content-Length value or None if not set. + Raises ValueError if header exists but cannot be parsed as an integer. 
+ """ + if hdrs.CONTENT_LENGTH not in self.headers: + return None - @reify - def history(self) -> Tuple["ClientResponse", ...]: - """A sequence of responses, if redirects occurred.""" - return self._history + content_length_hdr = self.headers[hdrs.CONTENT_LENGTH] + try: + return int(content_length_hdr) + except ValueError: + raise ValueError( + f"Invalid Content-Length header: {content_length_hdr}" + ) from None - @reify - def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": - links_str = ", ".join(self.headers.getall("link", [])) - - if not links_str: - return MultiDictProxy(MultiDict()) - - links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict() - - for val in re.split(r",(?=\s*<)", links_str): - match = re.match(r"\s*<(.*)>(.*)", val) - if match is None: # Malformed link - continue - url, params_str = match.groups() - params = params_str.split(";")[1:] + def _create_response(self, task: Optional[asyncio.Task[None]]) -> ClientResponse: + return self.response_class( + self.method, + self.original_url, + writer=task, + continue100=None, + timer=TimerNoop(), + request_info=self._request_info, + traces=(), + loop=self.loop, + session=None, + ) - link: MultiDict[Union[str, URL]] = MultiDict() + def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: + return StreamWriter(protocol, self.loop) - for param in params: - match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) - if match is None: # Malformed param - continue - key, _, value, _ = match.groups() + def _should_write(self, protocol: BaseProtocol) -> bool: + return protocol.writing_paused - link.add(key, value) + async def _send(self, conn: "Connection") -> "ClientResponse": + # Specify request target: + # - CONNECT request must send authority form URI + # - not CONNECT proxy must send absolute form URI + # - most common is origin form URI + if self.method == hdrs.METH_CONNECT: + connect_host = self.url.host_subcomponent + assert connect_host is not None + path 
= f"{connect_host}:{self.url.port}" + elif self.proxy and not self.is_ssl(): + path = str(self.url) + else: + path = self.url.raw_path_qs - key = link.get("rel", url) + protocol = conn.protocol + assert protocol is not None + writer = self._create_writer(protocol) - link.add("url", self.url.join(URL(url))) + # set default content-type + if ( + self.method in self.POST_METHODS + and ( + self._skip_auto_headers is None + or hdrs.CONTENT_TYPE not in self._skip_auto_headers + ) + and hdrs.CONTENT_TYPE not in self.headers + ): + self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" - links.add(str(key), MultiDictProxy(link)) + v = self.version + if hdrs.CONNECTION not in self.headers: + if conn._connector.force_close: + if v == HttpVersion11: + self.headers[hdrs.CONNECTION] = "close" + elif v == HttpVersion10: + self.headers[hdrs.CONNECTION] = "keep-alive" - return MultiDictProxy(links) + # status + headers + status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" - async def start(self, connection: "Connection") -> "ClientResponse": - """Start response processing.""" - self._closed = False - self._protocol = connection.protocol - self._connection = connection + # Buffer headers for potential coalescing with body + await writer.write_headers(status_line, self.headers) - with self._timer: - while True: - # read response - try: - protocol = self._protocol - message, payload = await protocol.read() # type: ignore[union-attr] - except http.HttpProcessingError as exc: - raise ClientResponseError( - self.request_info, - self.history, - status=exc.code, - message=exc.message, - headers=exc.headers, - ) from exc + task: Optional["asyncio.Task[None]"] + if self._should_write(protocol): + coro = self._write_bytes(writer, conn, self._get_content_length()) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to write + # bytes immediately to avoid having to schedule + # the task on the event loop. 
+ task = asyncio.Task(coro, loop=self.loop, eager_start=True) + else: + task = self.loop.create_task(coro) + if task.done(): + task = None + else: + self._writer = task + else: + # We have nothing to write because + # - there is no body + # - the protocol does not have writing paused + # - we are not waiting for a 100-continue response + protocol.start_timeout() + writer.set_eof() + task = None + self._response = self._create_response(task) + return self._response - if message.code < 100 or message.code > 199 or message.code == 101: - break + async def _write_bytes( + self, + writer: AbstractStreamWriter, + conn: "Connection", + content_length: Optional[int], + ) -> None: + # Base class never has a body, this will never be run. + assert False - if self._continue is not None: - set_result(self._continue, True) - self._continue = None - # payload eof handler - payload.on_eof(self._response_eof) +class ClientRequest(ClientRequestBase): + body = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) + _continue = None # waiter future for '100 Continue' response - # response status - self.version = message.version - self.status = message.code - self.reason = message.reason + GET_METHODS = { + hdrs.METH_GET, + hdrs.METH_HEAD, + hdrs.METH_OPTIONS, + hdrs.METH_TRACE, + } + DEFAULT_HEADERS = { + hdrs.ACCEPT: "*/*", + hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), + } - # headers - self._headers = message.headers # type is CIMultiDictProxy - self._raw_headers = message.raw_headers # type is Tuple[bytes, bytes] + def __init__( + self, + method: str, + url: URL, + *, + params: Query, + headers: CIMultiDict[str], + skip_auto_headers: Optional[Iterable[str]], + data: Any, + cookies: BaseCookie[str], + auth: Optional[BasicAuth], + version: HttpVersion, + compress: Union[str, bool], + chunked: Optional[bool], + expect100: bool, + loop: asyncio.AbstractEventLoop, + response_class: type["ClientResponse"], + proxy: Optional[URL], + proxy_auth: Optional[BasicAuth], + timer: 
BaseTimerContext, + session: "ClientSession", + ssl: Union[SSLContext, bool, Fingerprint], + proxy_headers: Optional[CIMultiDict[str]], + traces: list["Trace"], + trust_env: bool, + server_hostname: Optional[str], + **kwargs: object, + ): + # kwargs exists so authors of subclasses should expect to pass through unknown + # arguments. This allows us to safely add new arguments in future releases. + # But, we should never receive unknown arguments here in the parent class, this + # would indicate an argument has been named wrong or similar in the subclass. + assert not kwargs, "Unexpected arguments to ClientRequest" - # payload - self.content = payload + if params: + url = url.extend_query(params) + super().__init__(method, url, headers=headers, auth=auth, loop=loop, ssl=ssl) - # cookies - if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()): - cookies = SimpleCookie() - for hdr in cookie_hdrs: - try: - cookies.load(hdr) - except CookieError as exc: - client_logger.warning("Can not load response cookies: %s", exc) - self._cookies = cookies - return self + if proxy is not None: + assert type(proxy) is URL, proxy + self._session = session + self.chunked = chunked + self.response_class = response_class + self._timer = timer + self.server_hostname = server_hostname + self.version = version - def _response_eof(self) -> None: - if self._closed: - return + self._update_auto_headers(skip_auto_headers) + self._update_cookies(cookies) + self._update_content_encoding(data, compress) + self._update_proxy(proxy, proxy_auth, proxy_headers) - # protocol could be None because connection could be detached - protocol = self._connection and self._connection.protocol - if protocol is not None and protocol.upgraded: - return + self._update_body_from_data(data) + if data is not None or self.method not in self.GET_METHODS: + self._update_transfer_encoding() + self._update_expect_continue(expect100) + self._traces = traces - self._closed = True - self._cleanup_writer() - 
self._release_connection() + @property + def skip_auto_headers(self) -> CIMultiDict[None]: + return self._skip_auto_headers or CIMultiDict() @property - def closed(self) -> bool: - return self._closed + def connection_key(self) -> ConnectionKey: # type: ignore[misc] + if proxy_headers := self.proxy_headers: + h: Optional[int] = hash(tuple(proxy_headers.items())) + else: + h = None + url = self.url + return tuple.__new__( + ConnectionKey, + ( + url.raw_host or "", + url.port, + url.scheme in _SSL_SCHEMES, + self._ssl, + self.proxy, + self.proxy_auth, + h, + ), + ) - def close(self) -> None: - if not self._released: - self._notify_content() + @property + def session(self) -> "ClientSession": + """Return the ClientSession instance. - self._closed = True - if self._loop.is_closed(): - return + This property provides access to the ClientSession that initiated + this request, allowing middleware to make additional requests + using the same session. + """ + return self._session - self._cleanup_writer() - if self._connection is not None: - self._connection.close() - self._connection = None + def _update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None: + if skip_auto_headers is not None: + self._skip_auto_headers = CIMultiDict( + (hdr, None) for hdr in sorted(skip_auto_headers) + ) + used_headers = self.headers.copy() + used_headers.extend(self._skip_auto_headers) # type: ignore[arg-type] + else: + # Fast path when there are no headers to skip + # which is the most common case. 
+ used_headers = self.headers - def release(self) -> None: - if not self._released: - self._notify_content() + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in used_headers: + self.headers[hdr] = val - self._closed = True + if hdrs.USER_AGENT not in used_headers: + self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE - self._cleanup_writer() - self._release_connection() + def _update_cookies(self, cookies: BaseCookie[str]) -> None: + """Update request cookies header.""" + if not cookies: + return - @property - def ok(self) -> bool: - """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not. + c = SimpleCookie() + if hdrs.COOKIE in self.headers: + c.load(self.headers.get(hdrs.COOKIE, "")) + del self.headers[hdrs.COOKIE] - This is **not** a check for ``200 OK`` but a check that the response - status is under 400. - """ - return 400 > self.status + for name, value in cookies.items(): + # Preserve coded_value + mrsl_val = value.get(value.key, Morsel()) + mrsl_val.set(value.key, value.value, value.coded_value) + c[name] = mrsl_val - def raise_for_status(self) -> None: - if not self.ok: - # reason should always be not None for a started response - assert self.reason is not None + self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() - # If we're in a context we can rely on __aexit__() to release as the - # exception propagates. 
- if not self._in_context: - self.release() + def _update_content_encoding(self, data: Any, compress: Union[bool, str]) -> None: + """Set request content encoding.""" + self.compress = None + if not data: + return - raise ClientResponseError( - self.request_info, - self.history, - status=self.status, - message=self.reason, - headers=self.headers, - ) + if self.headers.get(hdrs.CONTENT_ENCODING): + if compress: + raise ValueError( + "compress can not be set if Content-Encoding header is set" + ) + elif compress: + self.compress = compress if isinstance(compress, str) else "deflate" + self.headers[hdrs.CONTENT_ENCODING] = self.compress + self.chunked = True # enable chunked, no need to deal with length - def _release_connection(self) -> None: - if self._connection is not None: - if self.__writer is None: - self._connection.release() - self._connection = None - else: - self.__writer.add_done_callback(lambda f: self._release_connection()) + def _update_transfer_encoding(self) -> None: + """Analyze transfer-encoding header.""" + te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() - async def _wait_released(self) -> None: - if self.__writer is not None: - try: - await self.__writer - except asyncio.CancelledError: - if ( - sys.version_info >= (3, 11) - and (task := asyncio.current_task()) - and task.cancelling() - ): - raise - self._release_connection() + if "chunked" in te: + if self.chunked: + raise ValueError( + "chunked can not be set " + 'if "Transfer-Encoding: chunked" header is set' + ) - def _cleanup_writer(self) -> None: - if self.__writer is not None: - self.__writer.cancel() - self._session = None + elif self.chunked: + if hdrs.CONTENT_LENGTH in self.headers: + raise ValueError( + "chunked can not be set if Content-Length header is set" + ) - def _notify_content(self) -> None: - content = self.content - # content can be None here, but the types are cheated elsewhere. 
- if content and content.exception() is None: # type: ignore[truthy-bool] - set_exception(content, _CONNECTION_CLOSED_EXCEPTION) - self._released = True + self.headers[hdrs.TRANSFER_ENCODING] = "chunked" - async def wait_for_close(self) -> None: - if self.__writer is not None: - try: - await self.__writer - except asyncio.CancelledError: - if ( - sys.version_info >= (3, 11) - and (task := asyncio.current_task()) - and task.cancelling() - ): - raise - self.release() + def _update_body_from_data(self, body: Any) -> None: + if body is None: + self.headers[hdrs.CONTENT_LENGTH] = "0" + return - async def read(self) -> bytes: - """Read response payload.""" - if self._body is None: + # FormData + if isinstance(body, FormData): + body = body() + else: try: - self._body = await self.content.read() - for trace in self._traces: - await trace.send_response_chunk_received( - self.method, self.url, self._body - ) - except BaseException: - self.close() - raise - elif self._released: # Response explicitly released - raise ClientConnectionError("Connection closed") + body = payload.PAYLOAD_REGISTRY.get(body, disposition=None) + except payload.LookupError: + boundary = None + if hdrs.CONTENT_TYPE in self.headers: + boundary = parse_mimetype( + self.headers[hdrs.CONTENT_TYPE] + ).parameters.get("boundary") + body = FormData(body, boundary=boundary)() - protocol = self._connection and self._connection.protocol - if protocol is None or not protocol.upgraded: - await self._wait_released() # Underlying connection released - return self._body + self.body = body - def get_encoding(self) -> str: - ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() - mimetype = helpers.parse_mimetype(ctype) + # enable chunked encoding if needed + if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers: + if (size := body.size) is not None: + self.headers[hdrs.CONTENT_LENGTH] = str(size) + else: + self.chunked = True - encoding = mimetype.parameters.get("charset") - if encoding: - with 
contextlib.suppress(LookupError, ValueError): - return codecs.lookup(encoding).name + # copy payload headers + assert body.headers + headers = self.headers + skip_headers = self._skip_auto_headers + for key, value in body.headers.items(): + if key in headers or (skip_headers is not None and key in skip_headers): + continue + headers[key] = value - if mimetype.type == "application" and ( - mimetype.subtype == "json" or mimetype.subtype == "rdap" + def _update_expect_continue(self, expect: bool = False) -> None: + if expect: + self.headers[hdrs.EXPECT] = "100-continue" + elif ( + hdrs.EXPECT in self.headers + and self.headers[hdrs.EXPECT].lower() == "100-continue" ): - # RFC 7159 states that the default encoding is UTF-8. - # RFC 7483 defines application/rdap+json - return "utf-8" + expect = True - if self._body is None: - raise RuntimeError( - "Cannot compute fallback encoding of a not yet read body" - ) + if expect: + self._continue = self.loop.create_future() - return self._resolve_charset(self, self._body) + def _update_proxy( + self, + proxy: Optional[URL], + proxy_auth: Optional[BasicAuth], + proxy_headers: Optional[CIMultiDict[str]], + ) -> None: + self.proxy = proxy + if proxy is None: + self.proxy_auth = None + self.proxy_headers = None + return - async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: - """Read response payload and decode.""" - await self.read() + if proxy_auth and not isinstance(proxy_auth, BasicAuth): + raise ValueError("proxy_auth must be None or BasicAuth() tuple") + self.proxy_auth = proxy_auth + self.proxy_headers = proxy_headers - if encoding is None: - encoding = self.get_encoding() + def _create_response(self, task: Optional[asyncio.Task[None]]) -> ClientResponse: + return self.response_class( + self.method, + self.original_url, + writer=task, + continue100=self._continue, + timer=self._timer, + request_info=self._request_info, + traces=self._traces, + loop=self.loop, + session=self._session, + ) - 
return self._body.decode(encoding, errors=errors) # type: ignore[union-attr] + def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: + writer = StreamWriter( + protocol, + self.loop, + on_chunk_sent=( + functools.partial(self._on_chunk_request_sent, self.method, self.url) + if self._traces + else None + ), + on_headers_sent=( + functools.partial(self._on_headers_request_sent, self.method, self.url) + if self._traces + else None + ), + ) - async def json( - self, - *, - encoding: Optional[str] = None, - loads: JSONDecoder = DEFAULT_JSON_DECODER, - content_type: Optional[str] = "application/json", - ) -> Any: - """Read and decodes JSON response.""" - await self.read() + if self.compress: + writer.enable_compression(self.compress) - if content_type: - if not is_expected_content_type(self.content_type, content_type): - raise ContentTypeError( - self.request_info, - self.history, - status=self.status, - message=( - "Attempt to decode JSON with " - "unexpected mimetype: %s" % self.content_type - ), - headers=self.headers, - ) + if self.chunked is not None: + writer.enable_chunking() + return writer - if encoding is None: - encoding = self.get_encoding() + def _should_write(self, protocol: BaseProtocol) -> bool: + return ( + self.body.size != 0 or self._continue is not None or protocol.writing_paused + ) - return loads(self._body.decode(encoding)) # type: ignore[union-attr] + async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: + for trace in self._traces: + await trace.send_request_chunk_sent(method, url, chunk) - async def __aenter__(self) -> "ClientResponse": - self._in_context = True - return self + async def _on_headers_request_sent( + self, method: str, url: URL, headers: "CIMultiDict[str]" + ) -> None: + for trace in self._traces: + await trace.send_request_headers(method, url, headers) - async def __aexit__( + async def _write_bytes( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - 
exc_tb: Optional[TracebackType], + writer: AbstractStreamWriter, + conn: "Connection", + content_length: Optional[int], ) -> None: - self._in_context = False - # similar to _RequestContextManager, we do not need to check - # for exceptions, response object can close connection - # if state is broken - self.release() - await self.wait_for_close() + """ + Write the request body to the connection stream. + + This method handles writing different types of request bodies: + 1. Payload objects (using their specialized write_with_length method) + 2. Bytes/bytearray objects + 3. Iterable body content + + Args: + writer: The stream writer to write the body to + conn: The connection being used for this request + content_length: Optional maximum number of bytes to write from the body + (None means write the entire body) + + The method properly handles: + - Waiting for 100-Continue responses if required + - Content length constraints for chunked encoding + - Error handling for network issues, cancellation, and other exceptions + - Signaling EOF and timeout management + + Raises: + ClientOSError: When there's an OS-level error writing the body + ClientConnectionError: When there's a general connection error + asyncio.CancelledError: When the operation is cancelled + + """ + # 100 response + if self._continue is not None: + # Force headers to be sent before waiting for 100-continue + writer.send_headers() + await writer.drain() + await self._continue + + protocol = conn.protocol + assert protocol is not None + try: + await self.body.write_with_length(writer, content_length) + except OSError as underlying_exc: + reraised_exc = underlying_exc + + # Distinguish between timeout and other OS errors for better error reporting + exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( + underlying_exc, asyncio.TimeoutError + ) + if exc_is_not_timeout: + reraised_exc = ClientOSError( + underlying_exc.errno, + f"Can not write request body for {self.url !s}", + ) + + 
set_exception(protocol, reraised_exc, underlying_exc) + except asyncio.CancelledError: + # Body hasn't been fully sent, so connection can't be reused + conn.close() + raise + except Exception as underlying_exc: + set_exception( + protocol, + ClientConnectionError( + "Failed to send bytes into the underlying connection " + f"{conn !s}: {underlying_exc!r}", + ), + underlying_exc, + ) + else: + # Successfully wrote the body, signal EOF and start response timeout + await writer.write_eof() + protocol.start_timeout() + + async def _close(self) -> None: + if self.__writer is not None: + try: + await self.__writer + except asyncio.CancelledError: + if ( + sys.version_info >= (3, 11) + and (task := asyncio.current_task()) + and task.cancelling() + ): + raise + + def _terminate(self) -> None: + if self.__writer is not None: + if not self.loop.is_closed(): + self.__writer.cancel() + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = None diff --git a/aiohttp/connector.py b/aiohttp/connector.py index c525ed92191..ee963e7cffe 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -34,6 +34,7 @@ import aiohappyeyeballs from aiohappyeyeballs import AddrInfoType, SocketFactoryType +from multidict import CIMultiDict from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -51,7 +52,12 @@ ssl_errors, ) from .client_proto import ResponseHandler -from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint +from .client_reqrep import ( + SSL_ALLOWED_TYPES, + ClientRequest, + ClientRequestBase, + Fingerprint, +) from .helpers import ( _SENTINEL, ceil_timeout, @@ -1068,7 +1074,7 @@ async def _create_connection( return proto - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequestBase) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. 
if req.ssl is false, return None @@ -1101,7 +1107,7 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: return _SSL_CONTEXT_UNVERIFIED return _SSL_CONTEXT_VERIFIED - def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: + def _get_fingerprint(self, req: ClientRequestBase) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -1114,7 +1120,7 @@ async def _wrap_create_connection( self, *args: Any, addr_infos: List[AddrInfoType], - req: ClientRequest, + req: ClientRequestBase, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, **kwargs: Any, @@ -1147,7 +1153,7 @@ def _warn_about_tls_in_tls( req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" - if req.request_info.url.scheme != "https": + if req.url.scheme != "https": return # Check if uvloop is being used, which supports TLS in TLS, @@ -1208,7 +1214,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.host, + server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1242,7 +1248,7 @@ async def _start_tls_connection( raise ClientConnectionError( "Cannot initialize a TLS-in-TLS connection to host " - f"{req.host!s}:{req.port:d} through an underlying connection " + f"{req.url.host!s}:{req.url.port:d} through an underlying connection " f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " f"[{type_err!s}]" ) from type_err @@ -1279,7 +1285,7 @@ def _convert_hosts_to_addr_infos( async def _create_direct_connection( self, - req: ClientRequest, + req: ClientRequestBase, traces: List["Trace"], timeout: "ClientTimeout", *, @@ -1295,7 +1301,7 @@ async def _create_direct_connection( # See https://github.com/aio-libs/aiohttp/pull/7364. if host.endswith(".."): host = host.rstrip(".") + "." 
- port = req.port + port = req.url.port assert port is not None try: # Cancelling this lookup should not cancel the underlying lookup @@ -1354,14 +1360,12 @@ async def _create_direct_connection( async def _create_proxy_connection( self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: - headers: Dict[str, str] = {} - if req.proxy_headers is not None: - headers = req.proxy_headers # type: ignore[assignment] + headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers headers[hdrs.HOST] = req.headers[hdrs.HOST] url = req.proxy assert url is not None - proxy_req = ClientRequest( + proxy_req = ClientRequestBase( hdrs.METH_GET, url, headers=headers, @@ -1400,7 +1404,7 @@ async def _create_proxy_connection( proxy=None, proxy_auth=None, proxy_headers_hash=None ) conn = Connection(self, key, proto, self._loop) - proxy_resp = await proxy_req.send(conn) + proxy_resp = await proxy_req._send(conn) try: protocol = conn._protocol assert protocol is not None diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 48cf0d2f229..444a52e9954 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -235,7 +235,7 @@ def clone( return self.__class__( message, self._payload, - self._protocol, # type: ignore[arg-type] + self._protocol, self._payload_writer, self._task, self._loop, diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index c430ef3a49c..19a0d13d3c8 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -2,19 +2,45 @@ import asyncio from http.cookies import BaseCookie -from typing import Union +from typing import Any, Union from multidict import CIMultiDict from pytest_codspeed import BenchmarkFixture from yarl import URL -from aiohttp.client_reqrep import ClientRequest, ClientResponse +from aiohttp.client_reqrep import ClientRequest as RawClientRequest, ClientResponse from 
aiohttp.cookiejar import CookieJar from aiohttp.helpers import TimerNoop from aiohttp.http_writer import HttpVersion11 from aiohttp.tracing import Trace +def ClientRequest(method: str, url: URL, **kwargs: Any) -> RawClientRequest: + default_args = { + "params": {}, + "headers": CIMultiDict[str](), + "skip_auto_headers": None, + "data": None, + "cookies": BaseCookie[str](), + "auth": None, + "version": HttpVersion11, + "compress": False, + "chunked": None, + "expect100": False, + "response_class": ClientResponse, + "proxy": None, + "proxy_auth": None, + "timer": TimerNoop(), + "session": None, # Shouldn't be None, but we don't have an async context. + "ssl": True, + "proxy_headers": None, + "traces": [], + "trust_env": False, + "server_hostname": None, + } + return RawClientRequest(method, url, **(default_args | kwargs)) + + def test_client_request_update_cookies( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture ) -> None: @@ -27,7 +53,7 @@ def test_client_request_update_cookies( @benchmark def _run() -> None: - req.update_cookies(cookies=cookies) + req._update_cookies(cookies=cookies) def test_create_client_request_with_cookies( @@ -44,7 +70,7 @@ def test_create_client_request_with_cookies( @benchmark def _run() -> None: - ClientRequest( + RawClientRequest( method="get", url=url, loop=loop, @@ -82,7 +108,7 @@ def test_create_client_request_with_headers( @benchmark def _run() -> None: - ClientRequest( + RawClientRequest( method="get", url=url, loop=loop, @@ -155,7 +181,7 @@ def __init__(self) -> None: async def send_requests() -> None: for _ in range(100): - await req.send(conn) # type: ignore[arg-type] + await req._send(conn) # type: ignore[arg-type] @benchmark def _run() -> None: diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 6b094171012..ea84d6d7917 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -35,6 +35,7 @@ ) from aiohttp.compression_utils import ZLibBackend from aiohttp.connector 
import Connection +from aiohttp.hdrs import METH_DELETE from aiohttp.http import HttpVersion10, HttpVersion11, StreamWriter from aiohttp.typedefs import LooseCookies @@ -51,6 +52,11 @@ def remove_done_callback(self, cb: Callable[[], None]) -> None: """Dummy method.""" +ALL_METHODS = frozenset( + (*ClientRequest.GET_METHODS, *ClientRequest.POST_METHODS, METH_DELETE) +) + + @pytest.fixture def make_request(loop: asyncio.AbstractEventLoop) -> Iterator[_RequestMaker]: request = None @@ -524,7 +530,7 @@ def test_cookies_merge_with_headers(make_request: _RequestMaker) -> None: def test_query_multivalued_param(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: + for meth in ALL_METHODS: req = make_request( meth, "http://python.org", params=(("test", "foo"), ("test", "baz")) ) @@ -533,19 +539,19 @@ def test_query_multivalued_param(make_request: _RequestMaker) -> None: def test_query_str_param(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: + for meth in ALL_METHODS: req = make_request(meth, "http://python.org", params="test=foo") assert str(req.url) == "http://python.org/?test=foo" def test_query_bytes_param_raises(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: + for meth in ALL_METHODS: with pytest.raises(TypeError): make_request(meth, "http://python.org", params=b"test=foo") def test_query_str_param_is_not_encoded(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: + for meth in ALL_METHODS: req = make_request(meth, "http://python.org", params="test=f+oo") assert str(req.url) == "http://python.org/?test=f+oo"