@@ -71,10 +71,14 @@ def __init__(
71
71
h2 .events .ResponseReceived
72
72
| h2 .events .DataReceived
73
73
| h2 .events .StreamEnded
74
- | h2 .events .StreamReset ,
74
+ | h2 .events .StreamReset
75
+ | h2 .events .TrailersReceived ,
75
76
],
76
77
] = {}
77
78
79
+ # Mapping from stream ID to trailing headers
80
+ self ._trailing_headers : dict [int , list [tuple [bytes , bytes ]]] = {}
81
+
78
82
# Connection terminated events are stored as state since
79
83
# we need to handle them for all streams.
80
84
self ._connection_terminated : h2 .events .ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response:
152
156
)
153
157
trace .return_value = (status , headers )
154
158
159
+ extensions = {
160
+ "http_version" : b"HTTP/2" ,
161
+ "network_stream" : self ._network_stream ,
162
+ "stream_id" : stream_id ,
163
+ }
164
+
155
165
return Response (
156
166
status = status ,
157
167
headers = headers ,
158
- content = HTTP2ConnectionByteStream (self , request , stream_id = stream_id ),
159
- extensions = {
160
- "http_version" : b"HTTP/2" ,
161
- "network_stream" : self ._network_stream ,
162
- "stream_id" : stream_id ,
163
- },
168
+ content = HTTP2ConnectionByteStream (
169
+ connection = self ,
170
+ request = request ,
171
+ stream_id = stream_id ,
172
+ extensions = extensions ,
173
+ ),
174
+ extensions = extensions ,
164
175
)
165
176
except BaseException as exc : # noqa: PIE786
166
177
with AsyncShieldCancellation ():
@@ -321,12 +332,21 @@ async def _receive_response_body(
321
332
self ._h2_state .acknowledge_received_data (amount , stream_id )
322
333
await self ._write_outgoing_data (request )
323
334
yield event .data
335
+ elif isinstance (event , h2 .events .TrailersReceived ):
336
+ # Process trailing headers but continue receiving events
337
+ # The trailing headers are already stored in self._trailing_headers
338
+ continue
324
339
elif isinstance (event , h2 .events .StreamEnded ):
325
340
break
326
341
327
342
async def _receive_stream_event (
328
343
self , request : Request , stream_id : int
329
- ) -> h2 .events .ResponseReceived | h2 .events .DataReceived | h2 .events .StreamEnded :
344
+ ) -> (
345
+ h2 .events .ResponseReceived
346
+ | h2 .events .DataReceived
347
+ | h2 .events .StreamEnded
348
+ | h2 .events .TrailersReceived
349
+ ):
330
350
"""
331
351
Return the next available event for a given stream ID.
332
352
@@ -377,10 +397,19 @@ async def _receive_events(
377
397
h2 .events .DataReceived ,
378
398
h2 .events .StreamEnded ,
379
399
h2 .events .StreamReset ,
400
+ h2 .events .TrailersReceived ,
380
401
),
381
402
):
382
403
if event .stream_id in self ._events :
383
404
self ._events [event .stream_id ].append (event )
405
+ if isinstance (event , h2 .events .TrailersReceived ):
406
+ self ._trailing_headers [event .stream_id ] = []
407
+ if event .headers is not None :
408
+ for k , v in event .headers :
409
+ if not k .startswith (b":" ):
410
+ self ._trailing_headers [
411
+ event .stream_id
412
+ ].append ((k , v ))
384
413
385
414
elif isinstance (event , h2 .events .ConnectionTerminated ):
386
415
self ._connection_terminated = event
@@ -409,6 +438,8 @@ async def _receive_remote_settings_change(
409
438
async def _response_closed (self , stream_id : int ) -> None :
410
439
await self ._max_streams_semaphore .release ()
411
440
del self ._events [stream_id ]
441
+ if stream_id in self ._trailing_headers :
442
+ del self ._trailing_headers [stream_id ]
412
443
async with self ._state_lock :
413
444
if self ._connection_terminated and not self ._events :
414
445
await self .aclose ()
@@ -561,12 +592,17 @@ async def __aexit__(
561
592
562
593
class HTTP2ConnectionByteStream :
563
594
def __init__ (
564
- self , connection : AsyncHTTP2Connection , request : Request , stream_id : int
595
+ self ,
596
+ connection : AsyncHTTP2Connection ,
597
+ request : Request ,
598
+ stream_id : int ,
599
+ extensions : typing .MutableMapping [str , typing .Any ],
565
600
) -> None :
566
601
self ._connection = connection
567
602
self ._request = request
568
603
self ._stream_id = stream_id
569
604
self ._closed = False
605
+ self ._extensions = extensions
570
606
571
607
async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
572
608
kwargs = {"request" : self ._request , "stream_id" : self ._stream_id }
@@ -576,6 +612,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576
612
request = self ._request , stream_id = self ._stream_id
577
613
):
578
614
yield chunk
615
+
616
+ if self ._stream_id in self ._connection ._trailing_headers :
617
+ self ._extensions ["trailing_headers" ] = (
618
+ self ._connection ._trailing_headers [self ._stream_id ]
619
+ )
579
620
except BaseException as exc :
580
621
# If we get an exception while streaming the response,
581
622
# we want to close the response (and possibly the connection)
0 commit comments