Skip to content

Commit 4cc98a3

Browse files
authored
[PR #10137/25c7f23 backport][3.11] Restore zero copy writes on Python 3.12.9+/3.13.2+ (#10390)
1 parent c60a8f4 commit 4cc98a3

File tree

4 files changed

+130
-5
lines changed

4 files changed

+130
-5
lines changed

.github/workflows/ci-cd.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,11 @@ jobs:
250250
uses: actions/checkout@v4
251251
with:
252252
submodules: true
253-
- name: Setup Python 3.13
253+
- name: Setup Python 3.13.2
254254
id: python-install
255255
uses: actions/setup-python@v5
256256
with:
257-
python-version: 3.13
257+
python-version: 3.13.2
258258
cache: pip
259259
cache-dependency-path: requirements/*.txt
260260
- name: Update pip, wheel, setuptools, build, twine

CHANGES/10137.misc.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Restored support for zero copy writes when using Python 3.12 versions 3.12.9 and later or Python 3.13.2+ -- by :user:`bdraco`.
2+
3+
Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions.

aiohttp/http_writer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Http related parsers and protocol."""
22

33
import asyncio
4+
import sys
45
import zlib
56
from typing import ( # noqa
67
Any,
@@ -24,6 +25,17 @@
2425
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
2526

2627

28+
MIN_PAYLOAD_FOR_WRITELINES = 2048
29+
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
30+
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
31+
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
32+
# writelines is not safe for use
33+
# on Python 3.12+ until 3.12.9
34+
# on Python 3.13+ until 3.13.2
35+
# and on older versions it not any faster than write
36+
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
37+
38+
2739
class HttpVersion(NamedTuple):
2840
major: int
2941
minor: int
@@ -90,7 +102,10 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
90102
transport = self._protocol.transport
91103
if transport is None or transport.is_closing():
92104
raise ClientConnectionResetError("Cannot write to closing transport")
93-
transport.write(b"".join(chunks))
105+
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
106+
transport.write(b"".join(chunks))
107+
else:
108+
transport.writelines(chunks)
94109

95110
async def write(
96111
self,

tests/test_http_writer.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import array
33
import asyncio
44
import zlib
5-
from typing import Iterable
5+
from typing import Generator, Iterable
66
from unittest import mock
77

88
import pytest
@@ -14,7 +14,19 @@
1414

1515

1616
@pytest.fixture
17-
def buf():
17+
def enable_writelines() -> Generator[None, None, None]:
18+
with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False):
19+
yield
20+
21+
22+
@pytest.fixture
23+
def force_writelines_small_payloads() -> Generator[None, None, None]:
24+
with mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1):
25+
yield
26+
27+
28+
@pytest.fixture
29+
def buf() -> bytearray:
1830
return bytearray()
1931

2032

@@ -117,6 +129,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
117129
assert zlib.decompress(content) == (b"data" * 4096) + payload
118130

119131

132+
@pytest.mark.usefixtures("enable_writelines")
133+
async def test_write_large_payload_deflate_compression_data_in_eof_writelines(
134+
protocol: BaseProtocol,
135+
transport: asyncio.Transport,
136+
loop: asyncio.AbstractEventLoop,
137+
) -> None:
138+
msg = http.StreamWriter(protocol, loop)
139+
msg.enable_compression("deflate")
140+
141+
await msg.write(b"data" * 4096)
142+
assert transport.write.called # type: ignore[attr-defined]
143+
chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined]
144+
transport.write.reset_mock() # type: ignore[attr-defined]
145+
assert not transport.writelines.called # type: ignore[attr-defined]
146+
147+
# This payload compresses to 20447 bytes
148+
payload = b"".join(
149+
[bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
150+
)
151+
await msg.write_eof(payload)
152+
assert not transport.write.called # type: ignore[attr-defined]
153+
assert transport.writelines.called # type: ignore[attr-defined]
154+
chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined]
155+
content = b"".join(chunks)
156+
assert zlib.decompress(content) == (b"data" * 4096) + payload
157+
158+
120159
async def test_write_payload_chunked_filter(
121160
protocol: BaseProtocol,
122161
transport: asyncio.Transport,
@@ -185,6 +224,26 @@ async def test_write_payload_deflate_compression_chunked(
185224
assert content == expected
186225

187226

227+
@pytest.mark.usefixtures("enable_writelines")
228+
@pytest.mark.usefixtures("force_writelines_small_payloads")
229+
async def test_write_payload_deflate_compression_chunked_writelines(
230+
protocol: BaseProtocol,
231+
transport: asyncio.Transport,
232+
loop: asyncio.AbstractEventLoop,
233+
) -> None:
234+
expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
235+
msg = http.StreamWriter(protocol, loop)
236+
msg.enable_compression("deflate")
237+
msg.enable_chunking()
238+
await msg.write(b"data")
239+
await msg.write_eof()
240+
241+
chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
242+
assert all(chunks)
243+
content = b"".join(chunks)
244+
assert content == expected
245+
246+
188247
async def test_write_payload_deflate_and_chunked(
189248
buf: bytearray,
190249
protocol: BaseProtocol,
@@ -221,6 +280,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
221280
assert content == expected
222281

223282

283+
@pytest.mark.usefixtures("enable_writelines")
284+
@pytest.mark.usefixtures("force_writelines_small_payloads")
285+
async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines(
286+
protocol: BaseProtocol,
287+
transport: asyncio.Transport,
288+
loop: asyncio.AbstractEventLoop,
289+
) -> None:
290+
expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
291+
msg = http.StreamWriter(protocol, loop)
292+
msg.enable_compression("deflate")
293+
msg.enable_chunking()
294+
await msg.write(b"data")
295+
await msg.write_eof(b"end")
296+
297+
chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
298+
assert all(chunks)
299+
content = b"".join(chunks)
300+
assert content == expected
301+
302+
224303
async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
225304
protocol: BaseProtocol,
226305
transport: asyncio.Transport,
@@ -247,6 +326,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
247326
assert zlib.decompress(content) == (b"data" * 4096) + payload
248327

249328

329+
@pytest.mark.usefixtures("enable_writelines")
330+
@pytest.mark.usefixtures("force_writelines_small_payloads")
331+
async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines(
332+
protocol: BaseProtocol,
333+
transport: asyncio.Transport,
334+
loop: asyncio.AbstractEventLoop,
335+
) -> None:
336+
msg = http.StreamWriter(protocol, loop)
337+
msg.enable_compression("deflate")
338+
msg.enable_chunking()
339+
340+
await msg.write(b"data" * 4096)
341+
# This payload compresses to 1111 bytes
342+
payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
343+
await msg.write_eof(payload)
344+
assert not transport.write.called # type: ignore[attr-defined]
345+
346+
chunks = []
347+
for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined]
348+
chunked_payload = list(write_lines_call[1][0])[1:]
349+
chunked_payload.pop()
350+
chunks.extend(chunked_payload)
351+
352+
assert all(chunks)
353+
content = b"".join(chunks)
354+
assert zlib.decompress(content) == (b"data" * 4096) + payload
355+
356+
250357
async def test_write_payload_deflate_compression_chunked_connection_lost(
251358
protocol: BaseProtocol,
252359
transport: asyncio.Transport,

0 commit comments

Comments
 (0)