Skip to content

Commit accae7b

Browse files
MtkN1lovelydinosaurT-256Tom Christie
authored
Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882)
* Add a starting point for the work * Add draft tests * Support connection `Upgrade` and `CONNECT`. * Update CHANGELOG.md * Remove private state assertions * Add Async prefix * Update CHANGELOG.md Co-authored-by: Tom Christie <[email protected]> * Update tests/_async/test_http11.py Co-authored-by: T-256 <[email protected]> --------- Co-authored-by: Tom Christie <[email protected]> Co-authored-by: T-256 <[email protected]> Co-authored-by: Tom Christie <[email protected]>
1 parent c468024 commit accae7b

File tree

5 files changed

+200
-6
lines changed

5 files changed

+200
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

7+
## Unreleased
8+
9+
- Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882)
10+
711
## 1.0.3 (February 13th, 2024)
812

913
- Fix support for async cancellations. (#880)

httpcore/_async/http11.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import enum
22
import logging
3+
import ssl
34
import time
45
from types import TracebackType
56
from typing import (
7+
Any,
68
AsyncIterable,
79
AsyncIterator,
810
List,
@@ -107,6 +109,7 @@ async def handle_async_request(self, request: Request) -> Response:
107109
status,
108110
reason_phrase,
109111
headers,
112+
trailing_data,
110113
) = await self._receive_response_headers(**kwargs)
111114
trace.return_value = (
112115
http_version,
@@ -115,14 +118,22 @@ async def handle_async_request(self, request: Request) -> Response:
115118
headers,
116119
)
117120

121+
network_stream = self._network_stream
122+
123+
# CONNECT or Upgrade request
124+
if (status == 101) or (
125+
(request.method == b"CONNECT") and (200 <= status < 300)
126+
):
127+
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)
128+
118129
return Response(
119130
status=status,
120131
headers=headers,
121132
content=HTTP11ConnectionByteStream(self, request),
122133
extensions={
123134
"http_version": http_version,
124135
"reason_phrase": reason_phrase,
125-
"network_stream": self._network_stream,
136+
"network_stream": network_stream,
126137
},
127138
)
128139
except BaseException as exc:
@@ -167,7 +178,7 @@ async def _send_event(
167178

168179
async def _receive_response_headers(
169180
self, request: Request
170-
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
181+
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
171182
timeouts = request.extensions.get("timeout", {})
172183
timeout = timeouts.get("read", None)
173184

@@ -187,7 +198,9 @@ async def _receive_response_headers(
187198
# raw header casing, rather than the enforced lowercase headers.
188199
headers = event.headers.raw_items()
189200

190-
return http_version, event.status_code, event.reason, headers
201+
trailing_data, _ = self._h11_state.trailing_data
202+
203+
return http_version, event.status_code, event.reason, headers, trailing_data
191204

192205
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
193206
timeouts = request.extensions.get("timeout", {})
@@ -340,3 +353,34 @@ async def aclose(self) -> None:
340353
self._closed = True
341354
async with Trace("response_closed", logger, self._request):
342355
await self._connection._response_closed()
356+
357+
358+
class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
359+
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
360+
self._stream = stream
361+
self._leading_data = leading_data
362+
363+
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
364+
if self._leading_data:
365+
buffer = self._leading_data[:max_bytes]
366+
self._leading_data = self._leading_data[max_bytes:]
367+
return buffer
368+
else:
369+
return await self._stream.read(max_bytes, timeout)
370+
371+
async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
372+
await self._stream.write(buffer, timeout)
373+
374+
async def aclose(self) -> None:
375+
await self._stream.aclose()
376+
377+
async def start_tls(
378+
self,
379+
ssl_context: ssl.SSLContext,
380+
server_hostname: Optional[str] = None,
381+
timeout: Optional[float] = None,
382+
) -> AsyncNetworkStream:
383+
return await self._stream.start_tls(ssl_context, server_hostname, timeout)
384+
385+
def get_extra_info(self, info: str) -> Any:
386+
return self._stream.get_extra_info(info)

httpcore/_sync/http11.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import enum
22
import logging
3+
import ssl
34
import time
45
from types import TracebackType
56
from typing import (
7+
Any,
68
Iterable,
79
Iterator,
810
List,
@@ -107,6 +109,7 @@ def handle_request(self, request: Request) -> Response:
107109
status,
108110
reason_phrase,
109111
headers,
112+
trailing_data,
110113
) = self._receive_response_headers(**kwargs)
111114
trace.return_value = (
112115
http_version,
@@ -115,14 +118,22 @@ def handle_request(self, request: Request) -> Response:
115118
headers,
116119
)
117120

121+
network_stream = self._network_stream
122+
123+
# CONNECT or Upgrade request
124+
if (status == 101) or (
125+
(request.method == b"CONNECT") and (200 <= status < 300)
126+
):
127+
network_stream = HTTP11UpgradeStream(network_stream, trailing_data)
128+
118129
return Response(
119130
status=status,
120131
headers=headers,
121132
content=HTTP11ConnectionByteStream(self, request),
122133
extensions={
123134
"http_version": http_version,
124135
"reason_phrase": reason_phrase,
125-
"network_stream": self._network_stream,
136+
"network_stream": network_stream,
126137
},
127138
)
128139
except BaseException as exc:
@@ -167,7 +178,7 @@ def _send_event(
167178

168179
def _receive_response_headers(
169180
self, request: Request
170-
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
181+
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
171182
timeouts = request.extensions.get("timeout", {})
172183
timeout = timeouts.get("read", None)
173184

@@ -187,7 +198,9 @@ def _receive_response_headers(
187198
# raw header casing, rather than the enforced lowercase headers.
188199
headers = event.headers.raw_items()
189200

190-
return http_version, event.status_code, event.reason, headers
201+
trailing_data, _ = self._h11_state.trailing_data
202+
203+
return http_version, event.status_code, event.reason, headers, trailing_data
191204

192205
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
193206
timeouts = request.extensions.get("timeout", {})
@@ -340,3 +353,34 @@ def close(self) -> None:
340353
self._closed = True
341354
with Trace("response_closed", logger, self._request):
342355
self._connection._response_closed()
356+
357+
358+
class HTTP11UpgradeStream(NetworkStream):
359+
def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
360+
self._stream = stream
361+
self._leading_data = leading_data
362+
363+
def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
364+
if self._leading_data:
365+
buffer = self._leading_data[:max_bytes]
366+
self._leading_data = self._leading_data[max_bytes:]
367+
return buffer
368+
else:
369+
return self._stream.read(max_bytes, timeout)
370+
371+
def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
372+
self._stream.write(buffer, timeout)
373+
374+
def close(self) -> None:
375+
self._stream.close()
376+
377+
def start_tls(
378+
self,
379+
ssl_context: ssl.SSLContext,
380+
server_hostname: Optional[str] = None,
381+
timeout: Optional[float] = None,
382+
) -> NetworkStream:
383+
return self._stream.start_tls(ssl_context, server_hostname, timeout)
384+
385+
def get_extra_info(self, info: str) -> Any:
386+
return self._stream.get_extra_info(info)

tests/_async/test_http11.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,57 @@ async def test_http11_upgrade_connection():
269269
assert content == b"..."
270270

271271

272+
@pytest.mark.anyio
273+
async def test_http11_upgrade_with_trailing_data():
274+
"""
275+
HTTP "101 Switching Protocols" indicates an upgraded connection.
276+
277+
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
278+
in the h11.Connection object.
279+
280+
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
281+
"""
282+
origin = httpcore.Origin(b"wss", b"example.com", 443)
283+
stream = httpcore.AsyncMockStream(
284+
# The first element of this mock network stream buffer simulates networking
285+
# in which response headers and data are received at once.
286+
# This means that "foobar" becomes trailing data.
287+
[
288+
(
289+
b"HTTP/1.1 101 Switching Protocols\r\n"
290+
b"Connection: upgrade\r\n"
291+
b"Upgrade: custom\r\n"
292+
b"\r\n"
293+
b"foobar"
294+
),
295+
b"baz",
296+
]
297+
)
298+
async with httpcore.AsyncHTTP11Connection(
299+
origin=origin, stream=stream, keepalive_expiry=5.0
300+
) as conn:
301+
async with conn.stream(
302+
"GET",
303+
"wss://example.com/",
304+
headers={"Connection": "upgrade", "Upgrade": "custom"},
305+
) as response:
306+
assert response.status == 101
307+
network_stream = response.extensions["network_stream"]
308+
309+
content = await network_stream.read(max_bytes=3)
310+
assert content == b"foo"
311+
content = await network_stream.read(max_bytes=3)
312+
assert content == b"bar"
313+
content = await network_stream.read(max_bytes=3)
314+
assert content == b"baz"
315+
316+
# Lazy tests for AsyncHTTP11UpgradeStream
317+
await network_stream.write(b"spam")
318+
invalid = network_stream.get_extra_info("invalid")
319+
assert invalid is None
320+
await network_stream.aclose()
321+
322+
272323
@pytest.mark.anyio
273324
async def test_http11_early_hints():
274325
"""

tests/_sync/test_http11.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,57 @@ def test_http11_upgrade_connection():
270270

271271

272272

273+
def test_http11_upgrade_with_trailing_data():
274+
"""
275+
HTTP "101 Switching Protocols" indicates an upgraded connection.
276+
277+
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
278+
in the h11.Connection object.
279+
280+
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
281+
"""
282+
origin = httpcore.Origin(b"wss", b"example.com", 443)
283+
stream = httpcore.MockStream(
284+
# The first element of this mock network stream buffer simulates networking
285+
# in which response headers and data are received at once.
286+
# This means that "foobar" becomes trailing data.
287+
[
288+
(
289+
b"HTTP/1.1 101 Switching Protocols\r\n"
290+
b"Connection: upgrade\r\n"
291+
b"Upgrade: custom\r\n"
292+
b"\r\n"
293+
b"foobar"
294+
),
295+
b"baz",
296+
]
297+
)
298+
with httpcore.HTTP11Connection(
299+
origin=origin, stream=stream, keepalive_expiry=5.0
300+
) as conn:
301+
with conn.stream(
302+
"GET",
303+
"wss://example.com/",
304+
headers={"Connection": "upgrade", "Upgrade": "custom"},
305+
) as response:
306+
assert response.status == 101
307+
network_stream = response.extensions["network_stream"]
308+
309+
content = network_stream.read(max_bytes=3)
310+
assert content == b"foo"
311+
content = network_stream.read(max_bytes=3)
312+
assert content == b"bar"
313+
content = network_stream.read(max_bytes=3)
314+
assert content == b"baz"
315+
316+
# Lazy tests for HTTP11UpgradeStream
317+
network_stream.write(b"spam")
318+
invalid = network_stream.get_extra_info("invalid")
319+
assert invalid is None
320+
network_stream.close()
321+
322+
323+
273324
def test_http11_early_hints():
274325
"""
275326
HTTP "103 Early Hints" is an interim response.

0 commit comments

Comments
 (0)