diff --git a/tests/integration/test_aiohttp.py b/tests/integration/test_aiohttp.py index c1ae7484..abf76914 100644 --- a/tests/integration/test_aiohttp.py +++ b/tests/integration/test_aiohttp.py @@ -466,3 +466,103 @@ def test_filter_query_parameters(tmpdir, httpbin): cassette_content = f.read() assert "password" not in cassette_content assert "secret" not in cassette_content + + +def test_raise_for_status_enabled_client_session(tmpdir, httpbin): + async def run(loop): + url = httpbin + "/status/404" + path = str(tmpdir.join("raise_for_status.yaml")) + + with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette: + async with aiohttp.ClientSession(loop=loop, raise_for_status=True) as session: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + await session.get(url) + assert exc_info.value.status == 404 + assert cassette.play_count == 0 + + with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette: + async with aiohttp.ClientSession(loop=loop, raise_for_status=True) as session: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + await session.get(url) + assert exc_info.value.status == 404 + assert cassette.play_count == 1 + + run_in_loop(run) + + +def test_raise_for_status_custom_client_session(tmpdir, httpbin): + async def run(loop): + url = httpbin + "/status/404" + path = str(tmpdir.join("custom_raise_for_status.yaml")) + + async def custom_raise_for_status(response): + if response.status == 404: + raise aiohttp.ClientResponseError( + response.request_info, + response.history, + status=response.status, + message="Custom error", + ) + + with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette: + async with aiohttp.ClientSession(loop=loop, raise_for_status=custom_raise_for_status) as session: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + await session.get(url) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Custom error" + assert cassette.play_count == 0 + + with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette: + async with aiohttp.ClientSession(loop=loop, raise_for_status=custom_raise_for_status) as session: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + await session.get(url) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Custom error" + assert cassette.play_count == 1 + + run_in_loop(run) + + +def test_raise_for_status_enabled_request(tmpdir, httpbin): + url = httpbin + "/status/404" + path = str(tmpdir.join("raise_for_status.yaml")) + + with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + get(url, raise_for_status=True) + assert exc_info.value.status == 404 + assert cassette.play_count == 0 + + with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + get(url, raise_for_status=True) + assert exc_info.value.status == 404 + assert cassette.play_count == 1 + + +def test_raise_for_status_custom_request(tmpdir, httpbin): + url = httpbin + "/status/404" + path = str(tmpdir.join("custom_raise_for_status.yaml")) + + async def custom_raise_for_status(response): + if response.status == 404: + raise aiohttp.ClientResponseError( + response.request_info, + response.history, + status=response.status, + message="Custom error", + ) + + with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + get(url, raise_for_status=custom_raise_for_status) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Custom error" + assert cassette.play_count == 0 + + with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette: + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + get(url, raise_for_status=custom_raise_for_status) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Custom error" + assert cassette.play_count == 1 diff --git a/vcr/stubs/aiohttp_stubs.py b/vcr/stubs/aiohttp_stubs.py index 88be90a7..9bf6035b 100644 --- a/vcr/stubs/aiohttp_stubs.py +++ b/vcr/stubs/aiohttp_stubs.py @@ -4,11 +4,19 @@ import functools import json import logging -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from http.cookies import CookieError, Morsel, SimpleCookie from typing import Union -from aiohttp import ClientConnectionError, ClientResponse, CookieJar, RequestInfo, hdrs, streams +from aiohttp import ( + ClientConnectionError, + ClientResponse, + ClientSession, + CookieJar, + RequestInfo, + hdrs, + streams, +) from aiohttp.helpers import strip_auth_from_url from multidict import CIMultiDict, CIMultiDictProxy, MultiDict from yarl import URL @@ -240,6 +248,28 @@ def _build_url_with_params(url_str: str, params: Mapping[str, Union[str, int, fl return url.with_query(q) +async def _raise_for_status_stub(response: ClientResponse): + """Stub for the raise_for_status parameter in aiohttp requests.""" + + +async def _raise_for_status( + self: ClientSession, + raise_for_status: Callable[[ClientResponse], Awaitable[None]] | bool | None, + response: ClientResponse, +): + """This mirrors the raise_for_status logic in + https://github.com/aio-libs/aiohttp/blob/7d56ed37752d220ca3bfd2bc753341d3c47762d8/aiohttp/client.py#L832 + """ + if raise_for_status is None: + raise_for_status = self._raise_for_status + if raise_for_status is None: + pass + elif callable(raise_for_status): + await raise_for_status(response) + elif raise_for_status: + response.raise_for_status() + + def vcr_request(cassette, real_request): @functools.wraps(real_request) async def new_request(self, method, url, **kwargs): @@ -249,6 +279,7 @@ async def new_request(self, method, url, **kwargs): data = kwargs.get("data", kwargs.get("json")) params = kwargs.get("params") cookies = kwargs.get("cookies") + raise_for_status = kwargs.pop("raise_for_status", None) if auth is not None: headers["AUTHORIZATION"] = auth.encode() @@ -267,6 +298,9 @@ async def new_request(self, method, url, **kwargs): for redirect in response.history: self._cookie_jar.update_cookies(redirect.cookies, redirect.url) self._cookie_jar.update_cookies(response.cookies, response.url) + + await _raise_for_status(self, raise_for_status, response) + return response if cassette.write_protected and cassette.filter_request(vcr_request): @@ -274,8 +308,13 @@ async def new_request(self, method, url, **kwargs): log.info("%s not in cassette, sending to real server", vcr_request) - response = await real_request(self, method, url, **kwargs) + # Override the raise_for_status parameter to avoid raising an exception before we can record the + # response. + response = await real_request(self, method, url, **kwargs, raise_for_status=_raise_for_status_stub) await record_responses(cassette, vcr_request, response) + + await _raise_for_status(self, raise_for_status, response) + return response return new_request