11import typing
2+ from contextlib import AsyncExitStack , asynccontextmanager
23
34import sniffio
45
56from .._models import Request , Response
67from .._types import AsyncByteStream
78from .base import AsyncBaseTransport
89
10+ try :
11+ import anyio
12+ except ImportError : # pragma: no cover
13+ anyio = None # type: ignore
14+
15+
916if typing .TYPE_CHECKING : # pragma: no cover
1017 import asyncio
1118
@@ -35,12 +42,19 @@ def create_event() -> "Event":
3542 return asyncio .Event ()
3643
3744
38- class ASGIResponseStream (AsyncByteStream ):
39- def __init__ (self , body : typing .List [bytes ]) -> None :
40- self ._body = body
45+ class ASGIResponseByteStream (AsyncByteStream ):
46+ def __init__ (
47+ self , stream : typing .AsyncGenerator [bytes , None ], app_context : AsyncExitStack
48+ ) -> None :
49+ self ._stream = stream
50+ self ._app_context = app_context
51+
52+ def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
53+ return self ._stream .__aiter__ ()
4154
42- async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
43- yield b"" .join (self ._body )
55+ async def aclose (self ) -> None :
56+ await self ._stream .aclose ()
57+ await self ._app_context .aclose ()
4458
4559
4660class ASGITransport (AsyncBaseTransport ):
@@ -83,6 +97,9 @@ def __init__(
8397 root_path : str = "" ,
8498 client : typing .Tuple [str , int ] = ("127.0.0.1" , 123 ),
8599 ) -> None :
100+ if anyio is None :
101+ raise RuntimeError ("ASGITransport requires anyio (Hint: pip install anyio)" )
102+
86103 self .app = app
87104 self .raise_app_exceptions = raise_app_exceptions
88105 self .root_path = root_path
@@ -92,82 +109,136 @@ async def handle_async_request(
92109 self ,
93110 request : Request ,
94111 ) -> Response :
95- assert isinstance (request .stream , AsyncByteStream )
96-
97- # ASGI scope.
98- scope = {
99- "type" : "http" ,
100- "asgi" : {"version" : "3.0" },
101- "http_version" : "1.1" ,
102- "method" : request .method ,
103- "headers" : [(k .lower (), v ) for (k , v ) in request .headers .raw ],
104- "scheme" : request .url .scheme ,
105- "path" : request .url .path ,
106- "raw_path" : request .url .raw_path ,
107- "query_string" : request .url .query ,
108- "server" : (request .url .host , request .url .port ),
109- "client" : self .client ,
110- "root_path" : self .root_path ,
111- }
112-
113- # Request.
114- request_body_chunks = request .stream .__aiter__ ()
115- request_complete = False
116-
117- # Response.
118- status_code = None
119- response_headers = None
120- body_parts = []
121- response_started = False
122- response_complete = create_event ()
123-
124- # ASGI callables.
125-
126- async def receive () -> typing .Dict [str , typing .Any ]:
127- nonlocal request_complete
128-
129- if request_complete :
130- await response_complete .wait ()
131- return {"type" : "http.disconnect" }
132-
133- try :
134- body = await request_body_chunks .__anext__ ()
135- except StopAsyncIteration :
136- request_complete = True
137- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
138- return {"type" : "http.request" , "body" : body , "more_body" : True }
139-
140- async def send (message : typing .Dict [str , typing .Any ]) -> None :
141- nonlocal status_code , response_headers , response_started
142-
143- if message ["type" ] == "http.response.start" :
144- assert not response_started
145-
146- status_code = message ["status" ]
147- response_headers = message .get ("headers" , [])
148- response_started = True
149-
150- elif message ["type" ] == "http.response.body" :
151- assert not response_complete .is_set ()
152- body = message .get ("body" , b"" )
153- more_body = message .get ("more_body" , False )
154-
155- if body and request .method != "HEAD" :
156- body_parts .append (body )
157-
158- if not more_body :
159- response_complete .set ()
160-
112+ exit_stack = AsyncExitStack ()
113+
114+ (
115+ status_code ,
116+ response_headers ,
117+ response_body ,
118+ ) = await exit_stack .enter_async_context (
119+ run_asgi (
120+ self .app ,
121+ raise_app_exceptions = self .raise_app_exceptions ,
122+ root_path = self .root_path ,
123+ client = self .client ,
124+ request = request ,
125+ )
126+ )
127+
128+ return Response (
129+ status_code ,
130+ headers = response_headers ,
131+ stream = ASGIResponseByteStream (response_body , exit_stack ),
132+ )
133+
134+
135+ @asynccontextmanager
136+ async def run_asgi (
137+ app : _ASGIApp ,
138+ raise_app_exceptions : bool ,
139+ client : typing .Tuple [str , int ],
140+ root_path : str ,
141+ request : Request ,
142+ ) -> typing .AsyncIterator [
143+ typing .Tuple [
144+ int ,
145+ typing .Sequence [typing .Tuple [bytes , bytes ]],
146+ typing .AsyncGenerator [bytes , None ],
147+ ]
148+ ]:
149+ # ASGI scope.
150+ scope = {
151+ "type" : "http" ,
152+ "asgi" : {"version" : "3.0" },
153+ "http_version" : "1.1" ,
154+ "method" : request .method ,
155+ "headers" : [(k .lower (), v ) for (k , v ) in request .headers .raw ],
156+ "scheme" : request .url .scheme ,
157+ "path" : request .url .path ,
158+ "raw_path" : request .url .raw_path ,
159+ "query_string" : request .url .query ,
160+ "server" : (request .url .host , request .url .port ),
161+ "client" : client ,
162+ "root_path" : root_path ,
163+ }
164+
165+ # Request.
166+ assert isinstance (request .stream , AsyncByteStream )
167+ request_body_chunks = request .stream .__aiter__ ()
168+ request_complete = False
169+
170+ # Response.
171+ status_code = None
172+ response_headers = None
173+ response_started = anyio .Event ()
174+ response_complete = anyio .Event ()
175+
176+ send_stream , receive_stream = anyio .create_memory_object_stream ()
177+ disconnected = anyio .Event ()
178+
179+ async def watch_disconnect (cancel_scope : anyio .CancelScope ) -> None :
180+ await disconnected .wait ()
181+ cancel_scope .cancel ()
182+
183+ async def run_app (cancel_scope : anyio .CancelScope ) -> None :
161184 try :
162- await self . app (scope , receive , send )
185+ await app (scope , receive , send )
163186 except Exception : # noqa: PIE-786
164- if self . raise_app_exceptions or not response_complete .is_set ():
187+ if raise_app_exceptions or not response_complete .is_set ():
165188 raise
166189
167- assert response_complete .is_set ()
190+ # ASGI callables.
191+
192+ async def receive () -> typing .Dict [str , typing .Any ]:
193+ nonlocal request_complete
194+
195+ if request_complete :
196+ await response_complete .wait ()
197+ return {"type" : "http.disconnect" }
198+
199+ try :
200+ body = await request_body_chunks .__anext__ ()
201+ except StopAsyncIteration :
202+ request_complete = True
203+ return {"type" : "http.request" , "body" : b"" , "more_body" : False }
204+ return {"type" : "http.request" , "body" : body , "more_body" : True }
205+
206+ async def send (message : _Message ) -> None :
207+ nonlocal status_code , response_headers
208+
209+ if disconnected .is_set ():
210+ return
211+
212+ if message ["type" ] == "http.response.start" :
213+ assert not response_started .is_set ()
214+
215+ status_code = message ["status" ]
216+ response_headers = message .get ("headers" , [])
217+ response_started .set ()
218+
219+ elif message ["type" ] == "http.response.body" :
220+ assert response_started .is_set ()
221+ assert not response_complete .is_set ()
222+ body = message .get ("body" , b"" )
223+ more_body = message .get ("more_body" , False )
224+
225+ if body and request .method != "HEAD" :
226+ await send_stream .send (body )
227+
228+ if not more_body :
229+ response_complete .set ()
230+
231+ async with anyio .create_task_group () as tg :
232+ tg .start_soon (watch_disconnect , tg .cancel_scope )
233+ tg .start_soon (run_app , tg .cancel_scope )
234+
235+ await response_started .wait ()
168236 assert status_code is not None
169237 assert response_headers is not None
170238
171- stream = ASGIResponseStream (body_parts )
239+ async def stream () -> typing .AsyncGenerator [bytes , None ]:
240+ async for chunk in receive_stream :
241+ yield chunk
172242
173- return Response (status_code , headers = response_headers , stream = stream )
243+ yield (status_code , response_headers , stream ())
244+ disconnected .set ()
0 commit comments