99if typing .TYPE_CHECKING : # pragma: no cover
1010 import asyncio
1111
12+ import anyio .abc
13+ import anyio .streams .memory
1214 import trio
1315
1416 Event = typing .Union [asyncio .Event , trio .Event ]
17+ MessageReceiveStream = typing .Union [
18+ anyio .streams .memory .MemoryObjectReceiveStream ["_Message" ],
19+ trio .MemoryReceiveChannel ["_Message" ],
20+ ]
21+ MessageSendStream = typing .Union [
22+ anyio .streams .memory .MemoryObjectSendStream ["_Message" ],
23+ trio .MemorySendChannel ["_Message" ],
24+ ]
25+ TaskGroup = typing .Union [anyio .abc .TaskGroup , trio .Nursery ]
1526
1627
1728_Message = typing .MutableMapping [str , typing .Any ]
@@ -50,12 +61,71 @@ def create_event() -> Event:
5061 return asyncio .Event ()
5162
5263
64+ def create_memory_object_stream (
65+ max_buffer_size : float ,
66+ ) -> tuple [MessageSendStream , MessageReceiveStream ]:
67+ if is_running_trio ():
68+ import trio
69+
70+ return trio .open_memory_channel (max_buffer_size )
71+
72+ import anyio
73+
74+ return anyio .create_memory_object_stream (max_buffer_size )
75+
76+
77+ def create_task_group () -> typing .AsyncContextManager [TaskGroup ]:
78+ if is_running_trio ():
79+ import trio
80+
81+ return trio .open_nursery ()
82+
83+ import anyio
84+
85+ return anyio .create_task_group ()
86+
87+
88+ def get_end_of_stream_error_type () -> type [anyio .EndOfStream | trio .EndOfChannel ]:
89+ if is_running_trio ():
90+ import trio
91+
92+ return trio .EndOfChannel
93+
94+ import anyio
95+
96+ return anyio .EndOfStream
97+
98+
5399class ASGIResponseStream (AsyncByteStream ):
54- def __init__ (self , body : list [bytes ]) -> None :
55- self ._body = body
100+ def __init__ (
101+ self ,
102+ ignore_body : bool ,
103+ asgi_generator : typing .AsyncGenerator [_Message , None ],
104+ disconnect_request_event : Event ,
105+ ) -> None :
106+ self ._ignore_body = ignore_body
107+ self ._asgi_generator = asgi_generator
108+ self ._disconnect_request_event = disconnect_request_event
56109
57110 async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
58- yield b"" .join (self ._body )
111+ more_body = True
112+ try :
113+ async for message in self ._asgi_generator :
114+ assert message ["type" ] != "http.response.start"
115+ if message ["type" ] == "http.response.body" :
116+ assert more_body
117+ chunk = message .get ("body" , b"" )
118+ more_body = message .get ("more_body" , False )
119+ if chunk and not self ._ignore_body :
120+ yield chunk
121+ if not more_body :
122+ self ._disconnect_request_event .set ()
123+ finally :
124+ await self .aclose ()
125+
126+ async def aclose (self ) -> None :
127+ self ._disconnect_request_event .set ()
128+ await self ._asgi_generator .aclose ()
59129
60130
61131class ASGITransport (AsyncBaseTransport ):
@@ -98,6 +168,27 @@ async def handle_async_request(
98168 self ,
99169 request : Request ,
100170 ) -> Response :
171+ disconnect_request_event = create_event ()
172+ asgi_generator = self ._stream_asgi_messages (request , disconnect_request_event )
173+
174+ async for message in asgi_generator :
175+ if message ["type" ] == "http.response.start" :
176+ return Response (
177+ status_code = message ["status" ],
178+ headers = message .get ("headers" , []),
179+ stream = ASGIResponseStream (
180+ ignore_body = request .method == "HEAD" ,
181+ asgi_generator = asgi_generator ,
182+ disconnect_request_event = disconnect_request_event ,
183+ ),
184+ )
185+ else :
186+ disconnect_request_event .set ()
187+ return Response (status_code = 500 , headers = [])
188+
189+ async def _stream_asgi_messages (
190+ self , request : Request , disconnect_request_event : Event
191+ ) -> typing .AsyncGenerator [typing .MutableMapping [str , typing .Any ]]:
101192 assert isinstance (request .stream , AsyncByteStream )
102193
103194 # ASGI scope.
@@ -120,20 +211,21 @@ async def handle_async_request(
120211 request_body_chunks = request .stream .__aiter__ ()
121212 request_complete = False
122213
123- # Response.
124- status_code = None
125- response_headers = None
126- body_parts = []
127- response_started = False
128- response_complete = create_event ()
214+ # ASGI response messages stream
215+ response_message_send_stream , response_message_recv_stream = (
216+ create_memory_object_stream (0 )
217+ )
218+
219+ # ASGI app exception
220+ app_exception : Exception | None = None
129221
130222 # ASGI callables.
131223
132224 async def receive () -> _Message :
133225 nonlocal request_complete
134226
135227 if request_complete :
136- await response_complete .wait ()
228+ await disconnect_request_event .wait ()
137229 return {"type" : "http.disconnect" }
138230
139231 try :
@@ -143,43 +235,25 @@ async def receive() -> _Message:
143235 return {"type" : "http.request" , "body" : b"" , "more_body" : False }
144236 return {"type" : "http.request" , "body" : body , "more_body" : True }
145237
146- async def send (message : _Message ) -> None :
147- nonlocal status_code , response_headers , response_started
148-
149- if message ["type" ] == "http.response.start" :
150- assert not response_started
151-
152- status_code = message ["status" ]
153- response_headers = message .get ("headers" , [])
154- response_started = True
155-
156- elif message ["type" ] == "http.response.body" :
157- assert not response_complete .is_set ()
158- body = message .get ("body" , b"" )
159- more_body = message .get ("more_body" , False )
160-
161- if body and request .method != "HEAD" :
162- body_parts .append (body )
163-
164- if not more_body :
165- response_complete .set ()
166-
167- try :
168- await self .app (scope , receive , send )
169- except Exception : # noqa: PIE-786
170- if self .raise_app_exceptions :
171- raise
172-
173- response_complete .set ()
174- if status_code is None :
175- status_code = 500
176- if response_headers is None :
177- response_headers = {}
178-
179- assert response_complete .is_set ()
180- assert status_code is not None
181- assert response_headers is not None
182-
183- stream = ASGIResponseStream (body_parts )
184-
185- return Response (status_code , headers = response_headers , stream = stream )
238+ async def run_app () -> None :
239+ nonlocal app_exception
240+ try :
241+ await self .app (scope , receive , response_message_send_stream .send )
242+ except Exception as ex :
243+ app_exception = ex
244+ finally :
245+ await response_message_send_stream .aclose ()
246+
247+ async with create_task_group () as task_group :
248+ task_group .start_soon (run_app )
249+
250+ async with response_message_recv_stream :
251+ try :
252+ while True :
253+ message = await response_message_recv_stream .receive ()
254+ yield message
255+ except get_end_of_stream_error_type ():
256+ pass
257+
258+ if app_exception is not None and self .raise_app_exceptions :
259+ raise app_exception
0 commit comments