Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/integration/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,103 @@ def test_filter_query_parameters(tmpdir, httpbin):
cassette_content = f.read()
assert "password" not in cassette_content
assert "secret" not in cassette_content


def test_raise_for_status_enabled_client_session(tmpdir, httpbin):
async def run(loop):
url = httpbin + "/status/404"
path = str(tmpdir.join("raise_for_status.yaml"))

with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette:
async with aiohttp.ClientSession(loop=loop, raise_for_status=True) as session:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await session.get(url)
assert exc_info.value.status == 404
assert cassette.play_count == 0

with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette:
async with aiohttp.ClientSession(loop=loop, raise_for_status=True) as session:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await session.get(url)
assert exc_info.value.status == 404
assert cassette.play_count == 1

run_in_loop(run)


def test_raise_for_status_custom_client_session(tmpdir, httpbin):
async def run(loop):
url = httpbin + "/status/404"
path = str(tmpdir.join("custom_raise_for_status.yaml"))

async def custom_raise_for_status(response):
if response.status == 404:
raise aiohttp.ClientResponseError(
response.request_info,
response.history,
status=response.status,
message="Custom error",
)

with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette:
async with aiohttp.ClientSession(loop=loop, raise_for_status=custom_raise_for_status) as session:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await session.get(url)
assert exc_info.value.status == 404
assert exc_info.value.message == "Custom error"
assert cassette.play_count == 0

with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette:
async with aiohttp.ClientSession(loop=loop, raise_for_status=custom_raise_for_status) as session:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
await session.get(url)
assert exc_info.value.status == 404
assert exc_info.value.message == "Custom error"
assert cassette.play_count == 1

run_in_loop(run)


def test_raise_for_status_enabled_request(tmpdir, httpbin):
url = httpbin + "/status/404"
path = str(tmpdir.join("raise_for_status.yaml"))

with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
get(url, raise_for_status=True)
assert exc_info.value.status == 404
assert cassette.play_count == 0

with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
get(url, raise_for_status=True)
assert exc_info.value.status == 404
assert cassette.play_count == 1


def test_raise_for_status_custom_request(tmpdir, httpbin):
url = httpbin + "/status/404"
path = str(tmpdir.join("custom_raise_for_status.yaml"))

async def custom_raise_for_status(response):
if response.status == 404:
raise aiohttp.ClientResponseError(
response.request_info,
response.history,
status=response.status,
message="Custom error",
)

with vcr.use_cassette(path, record_mode=vcr.mode.ALL) as cassette:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
get(url, raise_for_status=custom_raise_for_status)
assert exc_info.value.status == 404
assert exc_info.value.message == "Custom error"
assert cassette.play_count == 0

with vcr.use_cassette(path, record_mode=vcr.mode.NONE) as cassette:
with pytest.raises(aiohttp.ClientResponseError) as exc_info:
get(url, raise_for_status=custom_raise_for_status)
assert exc_info.value.status == 404
assert exc_info.value.message == "Custom error"
assert cassette.play_count == 1
45 changes: 42 additions & 3 deletions vcr/stubs/aiohttp_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
import functools
import json
import logging
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from http.cookies import CookieError, Morsel, SimpleCookie
from typing import Union

from aiohttp import ClientConnectionError, ClientResponse, CookieJar, RequestInfo, hdrs, streams
from aiohttp import (
ClientConnectionError,
ClientResponse,
ClientSession,
CookieJar,
RequestInfo,
hdrs,
streams,
)
from aiohttp.helpers import strip_auth_from_url
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict
from yarl import URL
Expand Down Expand Up @@ -240,6 +248,28 @@ def _build_url_with_params(url_str: str, params: Mapping[str, Union[str, int, fl
return url.with_query(q)


async def _raise_for_status_stub(response: ClientResponse):
"""Stub for the raise_for_status parameter in aiohttp requests."""


async def _raise_for_status(
self: ClientSession,
raise_for_status: Callable[[ClientResponse], Awaitable[None]] | bool | None,
response: ClientResponse,
):
"""This mirrors the raise_for_status logic in
https://github.com/aio-libs/aiohttp/blob/7d56ed37752d220ca3bfd2bc753341d3c47762d8/aiohttp/client.py#L832
"""
if raise_for_status is None:
raise_for_status = self._raise_for_status
if raise_for_status is None:
pass
elif callable(raise_for_status):
await raise_for_status(response)
elif raise_for_status:
response.raise_for_status()


def vcr_request(cassette, real_request):
@functools.wraps(real_request)
async def new_request(self, method, url, **kwargs):
Expand All @@ -249,6 +279,7 @@ async def new_request(self, method, url, **kwargs):
data = kwargs.get("data", kwargs.get("json"))
params = kwargs.get("params")
cookies = kwargs.get("cookies")
raise_for_status = kwargs.pop("raise_for_status", None)

if auth is not None:
headers["AUTHORIZATION"] = auth.encode()
Expand All @@ -267,15 +298,23 @@ async def new_request(self, method, url, **kwargs):
for redirect in response.history:
self._cookie_jar.update_cookies(redirect.cookies, redirect.url)
self._cookie_jar.update_cookies(response.cookies, response.url)

await _raise_for_status(self, raise_for_status, response)

return response

if cassette.write_protected and cassette.filter_request(vcr_request):
raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)

log.info("%s not in cassette, sending to real server", vcr_request)

response = await real_request(self, method, url, **kwargs)
# Override the raise_for_status parameter to avoid raising an exception before we can record the
# response.
response = await real_request(self, method, url, **kwargs, raise_for_status=_raise_for_status_stub)
await record_responses(cassette, vcr_request, response)

await _raise_for_status(self, raise_for_status, response)

return response

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - was thinking it might be cleaner to rewrite lines 295-318 as follows, so that you don't have to repeat the _raise_for_status call?

        if cassette.can_play_response_for(vcr_request):
            log.info(f"Playing response for {vcr_request} from cassette")
            response = play_responses(cassette, vcr_request, kwargs)
            for redirect in response.history:
                self._cookie_jar.update_cookies(redirect.cookies, redirect.url)
            self._cookie_jar.update_cookies(response.cookies, response.url)
        else:
            if cassette.write_protected and cassette.filter_request(vcr_request):
                raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)

            log.info("%s not in cassette, sending to real server", vcr_request)

            response = await real_request(self, method, url, **kwargs, raise_for_status=False)
            await record_responses(cassette, vcr_request, response)

        await _raise_for_status(self, raise_for_status, response)

        return response

You could also just set raise_for_status in the real request to False, to save you from having to use the stub?


return new_request