Skip to content

Commit e249563

Browse files
committed
Stream response body in ASGITransport
Fixes #2186
1 parent cf989ae commit e249563

File tree

2 files changed

+191
-8
lines changed

2 files changed

+191
-8
lines changed

httpx/_transports/asgi.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import types
12
import typing
23

34
import sniffio
@@ -33,12 +34,75 @@ def create_event() -> "Event":
3334
return asyncio.Event()
3435

3536

37+
class _AwaitableRunner:
38+
def __init__(self, awaitable: typing.Awaitable[typing.Any]):
39+
self._generator = awaitable.__await__()
40+
self._started = False
41+
self._next_item: typing.Any = None
42+
self._finished = False
43+
44+
@types.coroutine
45+
def __call__(
46+
self, *, until: typing.Optional[typing.Callable[[], bool]] = None
47+
) -> typing.Generator[typing.Any, typing.Any, typing.Any]:
48+
while not self._finished and (until is None or not until()):
49+
send_value, throw_value = None, None
50+
if self._started:
51+
try:
52+
send_value = yield self._next_item
53+
except BaseException as e:
54+
throw_value = e
55+
56+
self._started = True
57+
try:
58+
if throw_value is not None:
59+
self._next_item = self._generator.throw(throw_value)
60+
else:
61+
self._next_item = self._generator.send(send_value)
62+
except StopIteration as e:
63+
self._finished = True
64+
return e.value
65+
except BaseException:
66+
self._generator.close()
67+
self._finished = True
68+
raise
69+
70+
3671
class ASGIResponseStream(AsyncByteStream):
37-
def __init__(self, body: typing.List[bytes]) -> None:
72+
def __init__(
73+
self,
74+
body: typing.List[bytes],
75+
raise_app_exceptions: bool,
76+
response_complete: "Event",
77+
app_runner: _AwaitableRunner,
78+
) -> None:
3879
self._body = body
80+
self._raise_app_exceptions = raise_app_exceptions
81+
self._response_complete = response_complete
82+
self._app_runner = app_runner
3983

4084
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
41-
yield b"".join(self._body)
85+
try:
86+
while bool(self._body) or not self._response_complete.is_set():
87+
if self._body:
88+
yield b"".join(self._body)
89+
self._body.clear()
90+
await self._app_runner(
91+
until=lambda: bool(self._body) or self._response_complete.is_set()
92+
)
93+
except Exception: # noqa: PIE786
94+
if self._raise_app_exceptions:
95+
raise
96+
finally:
97+
await self.aclose()
98+
99+
async def aclose(self) -> None:
100+
self._response_complete.set()
101+
try:
102+
await self._app_runner()
103+
except Exception: # noqa: PIE786
104+
if self._raise_app_exceptions:
105+
raise
42106

43107

44108
class ASGITransport(AsyncBaseTransport):
@@ -145,8 +209,10 @@ async def send(message: _Message) -> None:
145209
response_headers = message.get("headers", [])
146210
response_started = True
147211

148-
elif message["type"] == "http.response.body":
149-
assert not response_complete.is_set()
212+
elif (
213+
message["type"] == "http.response.body"
214+
and not response_complete.is_set()
215+
):
150216
body = message.get("body", b"")
151217
more_body = message.get("more_body", False)
152218

@@ -156,9 +222,11 @@ async def send(message: _Message) -> None:
156222
if not more_body:
157223
response_complete.set()
158224

225+
app_runner = _AwaitableRunner(self.app(scope, receive, send))
226+
159227
try:
160-
await self.app(scope, receive, send)
161-
except Exception: # noqa: PIE-786
228+
await app_runner(until=lambda: response_started)
229+
except Exception: # noqa: PIE786
162230
if self.raise_app_exceptions:
163231
raise
164232

@@ -168,10 +236,11 @@ async def send(message: _Message) -> None:
168236
if response_headers is None:
169237
response_headers = {}
170238

171-
assert response_complete.is_set()
172239
assert status_code is not None
173240
assert response_headers is not None
174241

175-
stream = ASGIResponseStream(body_parts)
242+
stream = ASGIResponseStream(
243+
body_parts, self.raise_app_exceptions, response_complete, app_runner
244+
)
176245

177246
return Response(status_code, headers=response_headers, stream=stream)

tests/test_asgi.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22

3+
import anyio
34
import pytest
45

56
import httpx
@@ -60,13 +61,24 @@ async def raise_exc(scope, receive, send):
6061
raise RuntimeError()
6162

6263

64+
async def raise_exc_after_response_start(scope, receive, send):
65+
status = 200
66+
output = b"Hello, World!"
67+
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
68+
69+
await send({"type": "http.response.start", "status": status, "headers": headers})
70+
await anyio.sleep(0)
71+
raise RuntimeError()
72+
73+
6374
async def raise_exc_after_response(scope, receive, send):
6475
status = 200
6576
output = b"Hello, World!"
6677
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
6778

6879
await send({"type": "http.response.start", "status": status, "headers": headers})
6980
await send({"type": "http.response.body", "body": output})
81+
await anyio.sleep(0)
7082
raise RuntimeError()
7183

7284

@@ -165,6 +177,14 @@ async def test_asgi_exc():
165177
await client.get("http://www.example.org/")
166178

167179

180+
@pytest.mark.anyio
181+
async def test_asgi_exc_after_response_start():
182+
transport = httpx.ASGITransport(app=raise_exc_after_response_start)
183+
async with httpx.AsyncClient(transport=transport) as client:
184+
with pytest.raises(RuntimeError):
185+
await client.get("http://www.example.org/")
186+
187+
168188
@pytest.mark.anyio
169189
async def test_asgi_exc_after_response():
170190
async with httpx.AsyncClient(app=raise_exc_after_response) as client:
@@ -213,3 +233,97 @@ async def test_asgi_exc_no_raise():
213233
response = await client.get("http://www.example.org/")
214234

215235
assert response.status_code == 500
236+
237+
238+
@pytest.mark.anyio
239+
async def test_asgi_exc_no_raise_after_response_start():
240+
transport = httpx.ASGITransport(
241+
app=raise_exc_after_response_start, raise_app_exceptions=False
242+
)
243+
async with httpx.AsyncClient(transport=transport) as client:
244+
response = await client.get("http://www.example.org/")
245+
246+
assert response.status_code == 200
247+
248+
249+
@pytest.mark.anyio
250+
async def test_asgi_exc_no_raise_after_response():
251+
transport = httpx.ASGITransport(
252+
app=raise_exc_after_response, raise_app_exceptions=False
253+
)
254+
async with httpx.AsyncClient(transport=transport) as client:
255+
response = await client.get("http://www.example.org/")
256+
257+
assert response.status_code == 200
258+
259+
260+
@pytest.mark.anyio
261+
async def test_asgi_stream_returns_before_waiting_for_body():
262+
start_response_body = anyio.Event()
263+
264+
async def send_response_body_after_event(scope, receive, send):
265+
status = 200
266+
headers = [(b"content-type", b"text/plain")]
267+
await send(
268+
{"type": "http.response.start", "status": status, "headers": headers}
269+
)
270+
await start_response_body.wait()
271+
await send({"type": "http.response.body", "body": b"body", "more_body": False})
272+
273+
transport = httpx.ASGITransport(app=send_response_body_after_event)
274+
async with httpx.AsyncClient(transport=transport) as client:
275+
async with client.stream("GET", "http://www.example.org/") as response:
276+
assert response.status_code == 200
277+
start_response_body.set()
278+
await response.aread()
279+
assert response.text == "body"
280+
281+
282+
@pytest.mark.anyio
283+
async def test_asgi_stream_allows_iterative_streaming():
284+
stream_events = [anyio.Event() for i in range(4)]
285+
286+
async def send_response_body_after_event(scope, receive, send):
287+
status = 200
288+
headers = [(b"content-type", b"text/plain")]
289+
await send(
290+
{"type": "http.response.start", "status": status, "headers": headers}
291+
)
292+
for e in stream_events:
293+
await e.wait()
294+
await send(
295+
{
296+
"type": "http.response.body",
297+
"body": b"chunk",
298+
"more_body": e is not stream_events[-1],
299+
}
300+
)
301+
302+
transport = httpx.ASGITransport(app=send_response_body_after_event)
303+
async with httpx.AsyncClient(transport=transport) as client:
304+
async with client.stream("GET", "http://www.example.org/") as response:
305+
assert response.status_code == 200
306+
iterator = response.aiter_raw()
307+
for e in stream_events:
308+
e.set()
309+
assert await iterator.__anext__() == b"chunk"
310+
with pytest.raises(StopAsyncIteration):
311+
await iterator.__anext__()
312+
313+
314+
@pytest.mark.anyio
315+
async def test_asgi_can_be_canceled():
316+
# This test exists to cover transmission of the cancellation exception through
317+
# _AwaitableRunner
318+
app_started = anyio.Event()
319+
320+
async def never_return(scope, receive, send):
321+
app_started.set()
322+
await anyio.sleep_forever()
323+
324+
transport = httpx.ASGITransport(app=never_return)
325+
async with httpx.AsyncClient(transport=transport) as client:
326+
async with anyio.create_task_group() as task_group:
327+
task_group.start_soon(client.get, "http://www.example.org/")
328+
await app_started.wait()
329+
task_group.cancel_scope.cancel()

0 commit comments

Comments
 (0)