Skip to content

Commit e70c821

Browse files
committed
http2: support trailer headers
1 parent 9820975 commit e70c821

File tree

4 files changed

+412
-18
lines changed

4 files changed

+412
-18
lines changed

httpcore/_async/http2.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
155165
return Response(
156166
status=status,
157167
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
168+
content=HTTP2ConnectionByteStream(
169+
connection=self,
170+
request=request,
171+
stream_id=stream_id,
172+
extensions=extensions,
173+
),
174+
extensions=extensions,
164175
)
165176
except BaseException as exc: # noqa: PIE786
166177
with AsyncShieldCancellation():
@@ -321,12 +332,21 @@ async def _receive_response_body(
321332
self._h2_state.acknowledge_received_data(amount, stream_id)
322333
await self._write_outgoing_data(request)
323334
yield event.data
335+
elif isinstance(event, h2.events.TrailersReceived):
336+
# Process trailing headers but continue receiving events
337+
# The trailing headers are already stored in self._trailing_headers
338+
continue
324339
elif isinstance(event, h2.events.StreamEnded):
325340
break
326341

327342
async def _receive_stream_event(
328343
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
344+
) -> (
345+
h2.events.ResponseReceived
346+
| h2.events.DataReceived
347+
| h2.events.StreamEnded
348+
| h2.events.TrailersReceived
349+
):
330350
"""
331351
Return the next available event for a given stream ID.
332352
@@ -377,10 +397,19 @@ async def _receive_events(
377397
h2.events.DataReceived,
378398
h2.events.StreamEnded,
379399
h2.events.StreamReset,
400+
h2.events.TrailersReceived,
380401
),
381402
):
382403
if event.stream_id in self._events:
383404
self._events[event.stream_id].append(event)
405+
if isinstance(event, h2.events.TrailersReceived):
406+
self._trailing_headers[event.stream_id] = []
407+
if event.headers is not None:
408+
for k, v in event.headers:
409+
if not k.startswith(b":"):
410+
self._trailing_headers[
411+
event.stream_id
412+
].append((k, v))
384413

385414
elif isinstance(event, h2.events.ConnectionTerminated):
386415
self._connection_terminated = event
@@ -409,6 +438,8 @@ async def _receive_remote_settings_change(
409438
async def _response_closed(self, stream_id: int) -> None:
410439
await self._max_streams_semaphore.release()
411440
del self._events[stream_id]
441+
if stream_id in self._trailing_headers:
442+
del self._trailing_headers[stream_id]
412443
async with self._state_lock:
413444
if self._connection_terminated and not self._events:
414445
await self.aclose()
@@ -561,12 +592,17 @@ async def __aexit__(
561592

562593
class HTTP2ConnectionByteStream:
563594
def __init__(
564-
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
595+
self,
596+
connection: AsyncHTTP2Connection,
597+
request: Request,
598+
stream_id: int,
599+
extensions: typing.MutableMapping[str, typing.Any],
565600
) -> None:
566601
self._connection = connection
567602
self._request = request
568603
self._stream_id = stream_id
569604
self._closed = False
605+
self._extensions = extensions
570606

571607
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
572608
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +612,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576612
request=self._request, stream_id=self._stream_id
577613
):
578614
yield chunk
615+
616+
if self._stream_id in self._connection._trailing_headers:
617+
self._extensions["trailing_headers"] = (
618+
self._connection._trailing_headers[self._stream_id]
619+
)
579620
except BaseException as exc:
580621
# If we get an exception while streaming the response,
581622
# we want to close the response (and possibly the connection)

httpcore/_sync/http2.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ def handle_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
155165
return Response(
156166
status=status,
157167
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
168+
content=HTTP2ConnectionByteStream(
169+
connection=self,
170+
request=request,
171+
stream_id=stream_id,
172+
extensions=extensions,
173+
),
174+
extensions=extensions,
164175
)
165176
except BaseException as exc: # noqa: PIE786
166177
with ShieldCancellation():
@@ -321,12 +332,21 @@ def _receive_response_body(
321332
self._h2_state.acknowledge_received_data(amount, stream_id)
322333
self._write_outgoing_data(request)
323334
yield event.data
335+
elif isinstance(event, h2.events.TrailersReceived):
336+
# Process trailing headers but continue receiving events
337+
# The trailing headers are already stored in self._trailing_headers
338+
continue
324339
elif isinstance(event, h2.events.StreamEnded):
325340
break
326341

327342
def _receive_stream_event(
328343
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
344+
) -> (
345+
h2.events.ResponseReceived
346+
| h2.events.DataReceived
347+
| h2.events.StreamEnded
348+
| h2.events.TrailersReceived
349+
):
330350
"""
331351
Return the next available event for a given stream ID.
332352
@@ -377,10 +397,19 @@ def _receive_events(
377397
h2.events.DataReceived,
378398
h2.events.StreamEnded,
379399
h2.events.StreamReset,
400+
h2.events.TrailersReceived,
380401
),
381402
):
382403
if event.stream_id in self._events:
383404
self._events[event.stream_id].append(event)
405+
if isinstance(event, h2.events.TrailersReceived):
406+
self._trailing_headers[event.stream_id] = []
407+
if event.headers is not None:
408+
for k, v in event.headers:
409+
if not k.startswith(b":"):
410+
self._trailing_headers[
411+
event.stream_id
412+
].append((k, v))
384413

385414
elif isinstance(event, h2.events.ConnectionTerminated):
386415
self._connection_terminated = event
@@ -409,6 +438,8 @@ def _receive_remote_settings_change(
409438
def _response_closed(self, stream_id: int) -> None:
410439
self._max_streams_semaphore.release()
411440
del self._events[stream_id]
441+
if stream_id in self._trailing_headers:
442+
del self._trailing_headers[stream_id]
412443
with self._state_lock:
413444
if self._connection_terminated and not self._events:
414445
self.close()
@@ -561,12 +592,17 @@ def __exit__(
561592

562593
class HTTP2ConnectionByteStream:
563594
def __init__(
564-
self, connection: HTTP2Connection, request: Request, stream_id: int
595+
self,
596+
connection: HTTP2Connection,
597+
request: Request,
598+
stream_id: int,
599+
extensions: typing.MutableMapping[str, typing.Any],
565600
) -> None:
566601
self._connection = connection
567602
self._request = request
568603
self._stream_id = stream_id
569604
self._closed = False
605+
self._extensions = extensions
570606

571607
def __iter__(self) -> typing.Iterator[bytes]:
572608
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +612,11 @@ def __iter__(self) -> typing.Iterator[bytes]:
576612
request=self._request, stream_id=self._stream_id
577613
):
578614
yield chunk
615+
616+
if self._stream_id in self._connection._trailing_headers:
617+
self._extensions["trailing_headers"] = (
618+
self._connection._trailing_headers[self._stream_id]
619+
)
579620
except BaseException as exc:
580621
# If we get an exception while streaming the response,
581622
# we want to close the response (and possibly the connection)

0 commit comments

Comments
 (0)