Skip to content

Commit f9dfd37

Browse files
Add support for streaming in ASGIDispatch
1 parent 8afd29a commit f9dfd37

File tree

2 files changed

+199
-76
lines changed

2 files changed

+199
-76
lines changed

httpx/_transports/asgi.py

Lines changed: 147 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import typing
2+
from contextlib import AsyncExitStack, asynccontextmanager
23

34
import sniffio
45

56
from .._models import Request, Response
67
from .._types import AsyncByteStream
78
from .base import AsyncBaseTransport
89

10+
try:
11+
import anyio
12+
except ImportError: # pragma: no cover
13+
anyio = None # type: ignore
14+
15+
916
if typing.TYPE_CHECKING: # pragma: no cover
1017
import asyncio
1118

@@ -35,12 +42,19 @@ def create_event() -> "Event":
3542
return asyncio.Event()
3643

3744

38-
class ASGIResponseStream(AsyncByteStream):
39-
def __init__(self, body: typing.List[bytes]) -> None:
40-
self._body = body
45+
class ASGIResponseByteStream(AsyncByteStream):
46+
def __init__(
47+
self, stream: typing.AsyncGenerator[bytes, None], app_context: AsyncExitStack
48+
) -> None:
49+
self._stream = stream
50+
self._app_context = app_context
51+
52+
def __aiter__(self) -> typing.AsyncIterator[bytes]:
53+
return self._stream.__aiter__()
4154

42-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
43-
yield b"".join(self._body)
55+
async def aclose(self) -> None:
56+
await self._stream.aclose()
57+
await self._app_context.aclose()
4458

4559

4660
class ASGITransport(AsyncBaseTransport):
@@ -83,6 +97,9 @@ def __init__(
8397
root_path: str = "",
8498
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
8599
) -> None:
100+
if anyio is None:
101+
raise RuntimeError("ASGITransport requires anyio (Hint: pip install anyio)")
102+
86103
self.app = app
87104
self.raise_app_exceptions = raise_app_exceptions
88105
self.root_path = root_path
@@ -92,82 +109,136 @@ async def handle_async_request(
92109
self,
93110
request: Request,
94111
) -> Response:
95-
assert isinstance(request.stream, AsyncByteStream)
96-
97-
# ASGI scope.
98-
scope = {
99-
"type": "http",
100-
"asgi": {"version": "3.0"},
101-
"http_version": "1.1",
102-
"method": request.method,
103-
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
104-
"scheme": request.url.scheme,
105-
"path": request.url.path,
106-
"raw_path": request.url.raw_path,
107-
"query_string": request.url.query,
108-
"server": (request.url.host, request.url.port),
109-
"client": self.client,
110-
"root_path": self.root_path,
111-
}
112-
113-
# Request.
114-
request_body_chunks = request.stream.__aiter__()
115-
request_complete = False
116-
117-
# Response.
118-
status_code = None
119-
response_headers = None
120-
body_parts = []
121-
response_started = False
122-
response_complete = create_event()
123-
124-
# ASGI callables.
125-
126-
async def receive() -> typing.Dict[str, typing.Any]:
127-
nonlocal request_complete
128-
129-
if request_complete:
130-
await response_complete.wait()
131-
return {"type": "http.disconnect"}
132-
133-
try:
134-
body = await request_body_chunks.__anext__()
135-
except StopAsyncIteration:
136-
request_complete = True
137-
return {"type": "http.request", "body": b"", "more_body": False}
138-
return {"type": "http.request", "body": body, "more_body": True}
139-
140-
async def send(message: typing.Dict[str, typing.Any]) -> None:
141-
nonlocal status_code, response_headers, response_started
142-
143-
if message["type"] == "http.response.start":
144-
assert not response_started
145-
146-
status_code = message["status"]
147-
response_headers = message.get("headers", [])
148-
response_started = True
149-
150-
elif message["type"] == "http.response.body":
151-
assert not response_complete.is_set()
152-
body = message.get("body", b"")
153-
more_body = message.get("more_body", False)
154-
155-
if body and request.method != "HEAD":
156-
body_parts.append(body)
157-
158-
if not more_body:
159-
response_complete.set()
160-
112+
exit_stack = AsyncExitStack()
113+
114+
(
115+
status_code,
116+
response_headers,
117+
response_body,
118+
) = await exit_stack.enter_async_context(
119+
run_asgi(
120+
self.app,
121+
raise_app_exceptions=self.raise_app_exceptions,
122+
root_path=self.root_path,
123+
client=self.client,
124+
request=request,
125+
)
126+
)
127+
128+
return Response(
129+
status_code,
130+
headers=response_headers,
131+
stream=ASGIResponseByteStream(response_body, exit_stack),
132+
)
133+
134+
135+
@asynccontextmanager
136+
async def run_asgi(
137+
app: _ASGIApp,
138+
raise_app_exceptions: bool,
139+
client: typing.Tuple[str, int],
140+
root_path: str,
141+
request: Request,
142+
) -> typing.AsyncIterator[
143+
typing.Tuple[
144+
int,
145+
typing.Sequence[typing.Tuple[bytes, bytes]],
146+
typing.AsyncGenerator[bytes, None],
147+
]
148+
]:
149+
# ASGI scope.
150+
scope = {
151+
"type": "http",
152+
"asgi": {"version": "3.0"},
153+
"http_version": "1.1",
154+
"method": request.method,
155+
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
156+
"scheme": request.url.scheme,
157+
"path": request.url.path,
158+
"raw_path": request.url.raw_path,
159+
"query_string": request.url.query,
160+
"server": (request.url.host, request.url.port),
161+
"client": client,
162+
"root_path": root_path,
163+
}
164+
165+
# Request.
166+
assert isinstance(request.stream, AsyncByteStream)
167+
request_body_chunks = request.stream.__aiter__()
168+
request_complete = False
169+
170+
# Response.
171+
status_code = None
172+
response_headers = None
173+
response_started = anyio.Event()
174+
response_complete = anyio.Event()
175+
176+
send_stream, receive_stream = anyio.create_memory_object_stream()
177+
disconnected = anyio.Event()
178+
179+
async def watch_disconnect(cancel_scope: anyio.CancelScope) -> None:
180+
await disconnected.wait()
181+
cancel_scope.cancel()
182+
183+
async def run_app(cancel_scope: anyio.CancelScope) -> None:
161184
try:
162-
await self.app(scope, receive, send)
185+
await app(scope, receive, send)
163186
except Exception: # noqa: PIE-786
164-
if self.raise_app_exceptions or not response_complete.is_set():
187+
if raise_app_exceptions or not response_complete.is_set():
165188
raise
166189

167-
assert response_complete.is_set()
190+
# ASGI callables.
191+
192+
async def receive() -> typing.Dict[str, typing.Any]:
193+
nonlocal request_complete
194+
195+
if request_complete:
196+
await response_complete.wait()
197+
return {"type": "http.disconnect"}
198+
199+
try:
200+
body = await request_body_chunks.__anext__()
201+
except StopAsyncIteration:
202+
request_complete = True
203+
return {"type": "http.request", "body": b"", "more_body": False}
204+
return {"type": "http.request", "body": body, "more_body": True}
205+
206+
async def send(message: _Message) -> None:
207+
nonlocal status_code, response_headers
208+
209+
if disconnected.is_set():
210+
return
211+
212+
if message["type"] == "http.response.start":
213+
assert not response_started.is_set()
214+
215+
status_code = message["status"]
216+
response_headers = message.get("headers", [])
217+
response_started.set()
218+
219+
elif message["type"] == "http.response.body":
220+
assert response_started.is_set()
221+
assert not response_complete.is_set()
222+
body = message.get("body", b"")
223+
more_body = message.get("more_body", False)
224+
225+
if body and request.method != "HEAD":
226+
await send_stream.send(body)
227+
228+
if not more_body:
229+
response_complete.set()
230+
231+
async with anyio.create_task_group() as tg:
232+
tg.start_soon(watch_disconnect, tg.cancel_scope)
233+
tg.start_soon(run_app, tg.cancel_scope)
234+
235+
await response_started.wait()
168236
assert status_code is not None
169237
assert response_headers is not None
170238

171-
stream = ASGIResponseStream(body_parts)
239+
async def stream() -> typing.AsyncGenerator[bytes, None]:
240+
async for chunk in receive_stream:
241+
yield chunk
172242

173-
return Response(status_code, headers=response_headers, stream=stream)
243+
yield (status_code, response_headers, stream())
244+
disconnected.set()

tests/test_asgi.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from contextlib import aclosing
23

34
import pytest
45

@@ -56,6 +57,20 @@ async def echo_headers(scope, receive, send):
5657
await send({"type": "http.response.body", "body": output})
5758

5859

60+
async def hello_world_endlessly(scope, receive, send):
61+
status = 200
62+
output = b"Hello, World!"
63+
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
64+
65+
await send({"type": "http.response.start", "status": status, "headers": headers})
66+
67+
k = 0
68+
while True:
69+
body = b"%d: %s\n" % (k, output)
70+
await send({"type": "http.response.body", "body": body, "more_body": True})
71+
k += 1
72+
73+
5974
async def raise_exc(scope, receive, send):
6075
raise RuntimeError()
6176

@@ -191,3 +206,40 @@ async def read_body(scope, receive, send):
191206

192207
assert response.status_code == 200
193208
assert disconnect
209+
210+
211+
@pytest.mark.anyio
212+
async def test_asgi_streaming():
213+
client = httpx.AsyncClient(app=hello_world_endlessly)
214+
async with client.stream("GET", "http://www.example.org/") as response:
215+
assert response.status_code == 200
216+
lines = []
217+
218+
async with aclosing(response.aiter_lines()) as stream:
219+
async for line in stream:
220+
if line.startswith("3: "):
221+
break
222+
lines.append(line)
223+
224+
assert lines == [
225+
"0: Hello, World!\n",
226+
"1: Hello, World!\n",
227+
"2: Hello, World!\n",
228+
]
229+
230+
231+
@pytest.mark.anyio
232+
async def test_asgi_streaming_exc():
233+
client = httpx.AsyncClient(app=raise_exc)
234+
with pytest.raises(RuntimeError):
235+
async with client.stream("GET", "http://www.example.org/"):
236+
pass # pragma: no cover
237+
238+
239+
@pytest.mark.anyio
240+
async def test_asgi_streaming_exc_after_response():
241+
client = httpx.AsyncClient(app=raise_exc_after_response)
242+
with pytest.raises(RuntimeError):
243+
async with client.stream("GET", "http://www.example.org/") as response:
244+
async for _ in response.aiter_bytes():
245+
pass # pragma: no cover

0 commit comments

Comments
 (0)