Skip to content

Commit cd982c2

Browse files
committed
implement streaming in ASGITransport
1 parent ab842b7 commit cd982c2

File tree

2 files changed

+248
-50
lines changed

2 files changed

+248
-50
lines changed

httpx/_transports/asgi.py

Lines changed: 124 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,20 @@
99
if typing.TYPE_CHECKING: # pragma: no cover
1010
import asyncio
1111

12+
import anyio.abc
13+
import anyio.streams.memory
1214
import trio
1315

1416
Event = typing.Union[asyncio.Event, trio.Event]
17+
MessageReceiveStream = typing.Union[
18+
anyio.streams.memory.MemoryObjectReceiveStream["_Message"],
19+
trio.MemoryReceiveChannel["_Message"],
20+
]
21+
MessageSendStream = typing.Union[
22+
anyio.streams.memory.MemoryObjectSendStream["_Message"],
23+
trio.MemorySendChannel["_Message"],
24+
]
25+
TaskGroup = typing.Union[anyio.abc.TaskGroup, trio.Nursery]
1526

1627

1728
_Message = typing.MutableMapping[str, typing.Any]
@@ -50,12 +61,71 @@ def create_event() -> Event:
5061
return asyncio.Event()
5162

5263

64+
def create_memory_object_stream(
65+
max_buffer_size: float,
66+
) -> tuple[MessageSendStream, MessageReceiveStream]:
67+
if is_running_trio():
68+
import trio
69+
70+
return trio.open_memory_channel(max_buffer_size)
71+
72+
import anyio
73+
74+
return anyio.create_memory_object_stream(max_buffer_size)
75+
76+
77+
def create_task_group() -> typing.AsyncContextManager[TaskGroup]:
78+
if is_running_trio():
79+
import trio
80+
81+
return trio.open_nursery()
82+
83+
import anyio
84+
85+
return anyio.create_task_group()
86+
87+
88+
def get_end_of_stream_error_type() -> type[anyio.EndOfStream | trio.EndOfChannel]:
89+
if is_running_trio():
90+
import trio
91+
92+
return trio.EndOfChannel
93+
94+
import anyio
95+
96+
return anyio.EndOfStream
97+
98+
5399
class ASGIResponseStream(AsyncByteStream):
54-
def __init__(self, body: list[bytes]) -> None:
55-
self._body = body
100+
def __init__(
101+
self,
102+
ignore_body: bool,
103+
asgi_generator: typing.AsyncGenerator[_Message, None],
104+
disconnect_request_event: Event,
105+
) -> None:
106+
self._ignore_body = ignore_body
107+
self._asgi_generator = asgi_generator
108+
self._disconnect_request_event = disconnect_request_event
56109

57110
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
58-
yield b"".join(self._body)
111+
more_body = True
112+
try:
113+
async for message in self._asgi_generator:
114+
assert message["type"] != "http.response.start"
115+
if message["type"] == "http.response.body":
116+
assert more_body
117+
chunk = message.get("body", b"")
118+
more_body = message.get("more_body", False)
119+
if chunk and not self._ignore_body:
120+
yield chunk
121+
if not more_body:
122+
self._disconnect_request_event.set()
123+
finally:
124+
await self.aclose()
125+
126+
async def aclose(self) -> None:
127+
self._disconnect_request_event.set()
128+
await self._asgi_generator.aclose()
59129

60130

61131
class ASGITransport(AsyncBaseTransport):
@@ -98,6 +168,27 @@ async def handle_async_request(
98168
self,
99169
request: Request,
100170
) -> Response:
171+
disconnect_request_event = create_event()
172+
asgi_generator = self._stream_asgi_messages(request, disconnect_request_event)
173+
174+
async for message in asgi_generator:
175+
if message["type"] == "http.response.start":
176+
return Response(
177+
status_code=message["status"],
178+
headers=message.get("headers", []),
179+
stream=ASGIResponseStream(
180+
ignore_body=request.method == "HEAD",
181+
asgi_generator=asgi_generator,
182+
disconnect_request_event=disconnect_request_event,
183+
),
184+
)
185+
else:
186+
disconnect_request_event.set()
187+
return Response(status_code=500, headers=[])
188+
189+
async def _stream_asgi_messages(
190+
self, request: Request, disconnect_request_event: Event
191+
) -> typing.AsyncGenerator[typing.MutableMapping[str, typing.Any]]:
101192
assert isinstance(request.stream, AsyncByteStream)
102193

103194
# ASGI scope.
@@ -120,20 +211,21 @@ async def handle_async_request(
120211
request_body_chunks = request.stream.__aiter__()
121212
request_complete = False
122213

123-
# Response.
124-
status_code = None
125-
response_headers = None
126-
body_parts = []
127-
response_started = False
128-
response_complete = create_event()
214+
# ASGI response messages stream
215+
response_message_send_stream, response_message_recv_stream = (
216+
create_memory_object_stream(0)
217+
)
218+
219+
# ASGI app exception
220+
app_exception: Exception | None = None
129221

130222
# ASGI callables.
131223

132224
async def receive() -> _Message:
133225
nonlocal request_complete
134226

135227
if request_complete:
136-
await response_complete.wait()
228+
await disconnect_request_event.wait()
137229
return {"type": "http.disconnect"}
138230

139231
try:
@@ -143,43 +235,25 @@ async def receive() -> _Message:
143235
return {"type": "http.request", "body": b"", "more_body": False}
144236
return {"type": "http.request", "body": body, "more_body": True}
145237

146-
async def send(message: _Message) -> None:
147-
nonlocal status_code, response_headers, response_started
148-
149-
if message["type"] == "http.response.start":
150-
assert not response_started
151-
152-
status_code = message["status"]
153-
response_headers = message.get("headers", [])
154-
response_started = True
155-
156-
elif message["type"] == "http.response.body":
157-
assert not response_complete.is_set()
158-
body = message.get("body", b"")
159-
more_body = message.get("more_body", False)
160-
161-
if body and request.method != "HEAD":
162-
body_parts.append(body)
163-
164-
if not more_body:
165-
response_complete.set()
166-
167-
try:
168-
await self.app(scope, receive, send)
169-
except Exception: # noqa: PIE-786
170-
if self.raise_app_exceptions:
171-
raise
172-
173-
response_complete.set()
174-
if status_code is None:
175-
status_code = 500
176-
if response_headers is None:
177-
response_headers = {}
178-
179-
assert response_complete.is_set()
180-
assert status_code is not None
181-
assert response_headers is not None
182-
183-
stream = ASGIResponseStream(body_parts)
184-
185-
return Response(status_code, headers=response_headers, stream=stream)
238+
async def run_app() -> None:
239+
nonlocal app_exception
240+
try:
241+
await self.app(scope, receive, response_message_send_stream.send)
242+
except Exception as ex:
243+
app_exception = ex
244+
finally:
245+
await response_message_send_stream.aclose()
246+
247+
async with create_task_group() as task_group:
248+
task_group.start_soon(run_app)
249+
250+
async with response_message_recv_stream:
251+
try:
252+
while True:
253+
message = await response_message_recv_stream.receive()
254+
yield message
255+
except get_end_of_stream_error_type():
256+
pass
257+
258+
if app_exception is not None and self.raise_app_exceptions:
259+
raise app_exception

tests/test_asgi.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
import json
22

3+
import anyio
34
import pytest
45

56
import httpx
67

78

9+
def run_in_task_group(app):
10+
"""A decorator that runs an ASGI callable in a task group"""
11+
12+
async def wrapped_app(*args):
13+
async with anyio.create_task_group() as task_group:
14+
task_group.start_soon(app, *args)
15+
16+
return wrapped_app
17+
18+
819
async def hello_world(scope, receive, send):
920
status = 200
1021
output = b"Hello, World!"
@@ -60,6 +71,15 @@ async def raise_exc(scope, receive, send):
6071
raise RuntimeError()
6172

6273

74+
async def raise_exc_after_response_start(scope, receive, send):
75+
status = 200
76+
output = b"Hello, World!"
77+
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
78+
79+
await send({"type": "http.response.start", "status": status, "headers": headers})
80+
raise RuntimeError()
81+
82+
6383
async def raise_exc_after_response(scope, receive, send):
6484
status = 200
6585
output = b"Hello, World!"
@@ -172,6 +192,14 @@ async def test_asgi_exc():
172192
await client.get("http://www.example.org/")
173193

174194

195+
@pytest.mark.anyio
196+
async def test_asgi_exc_after_response_start():
197+
transport = httpx.ASGITransport(app=raise_exc_after_response_start)
198+
async with httpx.AsyncClient(transport=transport) as client:
199+
with pytest.raises(RuntimeError):
200+
await client.get("http://www.example.org/")
201+
202+
175203
@pytest.mark.anyio
176204
async def test_asgi_exc_after_response():
177205
transport = httpx.ASGITransport(app=raise_exc_after_response)
@@ -222,3 +250,99 @@ async def test_asgi_exc_no_raise():
222250
response = await client.get("http://www.example.org/")
223251

224252
assert response.status_code == 500
253+
254+
255+
@pytest.mark.anyio
256+
async def test_asgi_exc_no_raise_after_response_start():
257+
transport = httpx.ASGITransport(
258+
app=raise_exc_after_response_start, raise_app_exceptions=False
259+
)
260+
async with httpx.AsyncClient(transport=transport) as client:
261+
response = await client.get("http://www.example.org/")
262+
263+
assert response.status_code == 200
264+
265+
266+
@pytest.mark.anyio
267+
async def test_asgi_exc_no_raise_after_response():
268+
transport = httpx.ASGITransport(
269+
app=raise_exc_after_response, raise_app_exceptions=False
270+
)
271+
async with httpx.AsyncClient(transport=transport) as client:
272+
response = await client.get("http://www.example.org/")
273+
274+
assert response.status_code == 200
275+
276+
277+
@pytest.mark.parametrize(
278+
"send_in_sub_task",
279+
[pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")],
280+
)
281+
@pytest.mark.anyio
282+
async def test_asgi_stream_returns_before_waiting_for_body(send_in_sub_task: bool):
283+
start_response_body = anyio.Event()
284+
285+
async def send_response_body_after_event(scope, receive, send):
286+
status = 200
287+
headers = [(b"content-type", b"text/plain")]
288+
await send(
289+
{"type": "http.response.start", "status": status, "headers": headers}
290+
)
291+
await start_response_body.wait()
292+
await send({"type": "http.response.body", "body": b"body", "more_body": False})
293+
294+
if send_in_sub_task:
295+
send_response_body_after_event = run_in_task_group(
296+
send_response_body_after_event
297+
)
298+
299+
transport = httpx.ASGITransport(app=send_response_body_after_event)
300+
async with httpx.AsyncClient(transport=transport) as client:
301+
with anyio.fail_after(0.1):
302+
async with client.stream("GET", "http://www.example.org/") as response:
303+
assert response.status_code == 200
304+
start_response_body.set()
305+
await response.aread()
306+
assert response.text == "body"
307+
308+
309+
@pytest.mark.parametrize(
310+
"send_in_sub_task",
311+
[pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")],
312+
)
313+
@pytest.mark.anyio
314+
async def test_asgi_stream_allows_iterative_streaming(send_in_sub_task: bool):
315+
stream_events = [anyio.Event() for i in range(4)]
316+
317+
async def send_response_body_after_event(scope, receive, send):
318+
status = 200
319+
headers = [(b"content-type", b"text/plain")]
320+
await send(
321+
{"type": "http.response.start", "status": status, "headers": headers}
322+
)
323+
for e in stream_events:
324+
await e.wait()
325+
await send(
326+
{
327+
"type": "http.response.body",
328+
"body": b"chunk",
329+
"more_body": e is not stream_events[-1],
330+
}
331+
)
332+
333+
if send_in_sub_task:
334+
send_response_body_after_event = run_in_task_group(
335+
send_response_body_after_event
336+
)
337+
338+
transport = httpx.ASGITransport(app=send_response_body_after_event)
339+
async with httpx.AsyncClient(transport=transport) as client:
340+
with anyio.fail_after(0.1):
341+
async with client.stream("GET", "http://www.example.org/") as response:
342+
assert response.status_code == 200
343+
iterator = response.aiter_raw()
344+
for e in stream_events:
345+
e.set()
346+
assert await iterator.__anext__() == b"chunk"
347+
with pytest.raises(StopAsyncIteration):
348+
await iterator.__anext__()

0 commit comments

Comments
 (0)