Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 46 additions & 27 deletions vcr/stubs/httpx_stubs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import functools
import inspect
import logging
Expand Down Expand Up @@ -36,16 +35,34 @@ def _transform_headers(httpx_response):
return out


async def _to_serialized_response(resp, aread):
def _to_serialized_response_sync(resp):
# The content shouldn't already have been read in by HTTPX.
assert not hasattr(resp, "_decoder")

# Retrieve the content, but without decoding it.
with patch.dict(resp.headers, {"Content-Encoding": ""}):
if aread:
await resp.aread()
else:
resp.read()
resp.read()

result = {
"status": {"code": resp.status_code, "message": resp.reason_phrase},
"headers": _transform_headers(resp),
"body": {"string": resp.content},
}

# As the content wasn't decoded, we restore the response to a state which
# will be capable of decoding the content for the consumer.
del resp._decoder
resp._content = resp._get_content_decoder().decode(resp.content)
return result


async def _to_serialized_response(resp):
# The content shouldn't already have been read in by HTTPX.
assert not hasattr(resp, "_decoder")

# Retrieve the content, but without decoding it.
with patch.dict(resp.headers, {"Content-Encoding": ""}):
await resp.aread()

result = {
"status": {"code": resp.status_code, "message": resp.reason_phrase},
Expand Down Expand Up @@ -126,17 +143,35 @@ def _shared_vcr_send(cassette, real_send, *args, **kwargs):
return vcr_request, None


async def _record_responses(cassette, vcr_request, real_response, aread):
def _record_responses_sync(cassette, vcr_request, real_response):
if not cassette.filter_request(vcr_request):
return real_response
for past_real_response in real_response.history:
past_vcr_request = _make_vcr_request(past_real_response.request)
cassette.append(past_vcr_request, await _to_serialized_response(past_real_response, aread))
cassette.append(past_vcr_request, _to_serialized_response_sync(past_real_response))

if real_response.history:
# If there was a redirection keep we want the request which will hold the
# final redirect value
vcr_request = _make_vcr_request(real_response.request)

cassette.append(vcr_request, await _to_serialized_response(real_response, aread))
cassette.append(vcr_request, _to_serialized_response_sync(real_response))
return real_response


async def _record_responses(cassette, vcr_request, real_response):
if not cassette.filter_request(vcr_request):
return real_response
for past_real_response in real_response.history:
past_vcr_request = _make_vcr_request(past_real_response.request)
cassette.append(past_vcr_request, await _to_serialized_response(past_real_response))

if real_response.history:
# If there was a redirection keep we want the request which will hold the
# final redirect value
vcr_request = _make_vcr_request(real_response.request)

cassette.append(vcr_request, await _to_serialized_response(real_response))
return real_response


Expand All @@ -154,7 +189,7 @@ async def _async_vcr_send(cassette, real_send, *args, **kwargs):
return response

real_response = await real_send(*args, **kwargs)
await _record_responses(cassette, vcr_request, real_response, aread=True)
await _record_responses(cassette, vcr_request, real_response)
return real_response


Expand All @@ -166,22 +201,6 @@ def _inner_send(*args, **kwargs):
return _inner_send


def _run_async_function(sync_func, *args, **kwargs):
"""
Safely run an asynchronous function from a synchronous context.
Handles both cases:
- An event loop is already running.
- No event loop exists yet.
"""
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(sync_func(*args, **kwargs))
else:
# If inside a running loop, create a task and wait for it
return asyncio.ensure_future(sync_func(*args, **kwargs))


def _sync_vcr_send(cassette, real_send, *args, **kwargs):
vcr_request, response = _shared_vcr_send(cassette, real_send, *args, **kwargs)
if response:
Expand All @@ -190,7 +209,7 @@ def _sync_vcr_send(cassette, real_send, *args, **kwargs):
return response

real_response = real_send(*args, **kwargs)
_run_async_function(_record_responses, cassette, vcr_request, real_response, aread=False)
_record_responses_sync(cassette, vcr_request, real_response)
return real_response


Expand Down
Loading