Skip to content

Commit 354489d

Browse files
authored
[PR #9839/a9a0d84 backport][3.11] Implement zero copy writes in StreamWriter (#9847)
1 parent c39032b commit 354489d

File tree

4 files changed

+191
-34
lines changed

4 files changed

+191
-34
lines changed

CHANGES/9839.misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implemented zero-copy writes for ``StreamWriter`` -- by :user:`bdraco`.

aiohttp/http_writer.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
import asyncio
44
import zlib
5-
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa
5+
from typing import ( # noqa
6+
Any,
7+
Awaitable,
8+
Callable,
9+
Iterable,
10+
List,
11+
NamedTuple,
12+
Optional,
13+
Union,
14+
)
615

716
from multidict import CIMultiDict
817

@@ -76,6 +85,17 @@ def _write(self, chunk: bytes) -> None:
7685
raise ClientConnectionResetError("Cannot write to closing transport")
7786
transport.write(chunk)
7887

88+
def _writelines(self, chunks: Iterable[bytes]) -> None:
    """Hand several byte chunks to the transport in a single call.

    The combined length of ``chunks`` is added to ``buffer_size`` and
    ``output_size`` before the transport is checked, so the counters are
    updated even when the write is rejected.

    :param chunks: the pieces to send, in order. Any iterable is accepted;
        one-shot iterators are materialized first (see below).
    :raises ClientConnectionResetError: if the protocol has no transport
        or the transport is closing.
    """
    # ``chunks`` is walked twice: once here to total the sizes and once by
    # ``transport.writelines``.  A generator or other one-shot iterator
    # would be exhausted by the first pass and nothing would reach the
    # transport, so anything that is not already a list/tuple is frozen
    # into a tuple up front.
    if not isinstance(chunks, (list, tuple)):
        chunks = tuple(chunks)
    size = sum(map(len, chunks))
    self.buffer_size += size
    self.output_size += size
    transport = self._protocol.transport
    if transport is None or transport.is_closing():
        raise ClientConnectionResetError("Cannot write to closing transport")
    transport.writelines(chunks)
98+
7999
async def write(
80100
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
81101
) -> None:
@@ -110,10 +130,11 @@ async def write(
110130

111131
if chunk:
112132
if self.chunked:
113-
chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
114-
chunk = chunk_len_pre + chunk + b"\r\n"
115-
116-
self._write(chunk)
133+
self._writelines(
134+
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
135+
)
136+
else:
137+
self._write(chunk)
117138

118139
if self.buffer_size > LIMIT and drain:
119140
self.buffer_size = 0
@@ -142,22 +163,31 @@ async def write_eof(self, chunk: bytes = b"") -> None:
142163
await self._on_chunk_sent(chunk)
143164

144165
if self._compress:
145-
if chunk:
146-
chunk = await self._compress.compress(chunk)
166+
chunks: List[bytes] = []
167+
chunks_len = 0
168+
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
169+
chunks_len = len(compressed_chunk)
170+
chunks.append(compressed_chunk)
147171

148-
chunk += self._compress.flush()
149-
if chunk and self.chunked:
150-
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
151-
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
152-
else:
153-
if self.chunked:
154-
if chunk:
155-
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
156-
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
157-
else:
158-
chunk = b"0\r\n\r\n"
172+
flush_chunk = self._compress.flush()
173+
chunks_len += len(flush_chunk)
174+
chunks.append(flush_chunk)
175+
assert chunks_len
159176

160-
if chunk:
177+
if self.chunked:
178+
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
179+
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
180+
elif len(chunks) > 1:
181+
self._writelines(chunks)
182+
else:
183+
self._write(chunks[0])
184+
elif self.chunked:
185+
if chunk:
186+
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
187+
self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
188+
else:
189+
self._write(b"0\r\n\r\n")
190+
elif chunk:
161191
self._write(chunk)
162192

163193
await self.drain()

tests/test_client_request.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import urllib.parse
77
import zlib
88
from http.cookies import BaseCookie, Morsel, SimpleCookie
9-
from typing import Any, Callable, Dict, Optional
9+
from typing import Any, Callable, Dict, Iterable, Optional
1010
from unittest import mock
1111

1212
import pytest
@@ -67,17 +67,18 @@ def protocol(loop, transport):
6767

6868

6969
@pytest.fixture
70-
def transport(buf):
71-
transport = mock.Mock()
70+
def transport(buf: bytearray) -> mock.Mock:
71+
transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)
7272

7373
def write(chunk):
7474
buf.extend(chunk)
7575

76-
async def write_eof():
77-
pass
76+
def writelines(chunks: Iterable[bytes]) -> None:
77+
for chunk in chunks:
78+
buf.extend(chunk)
7879

7980
transport.write.side_effect = write
80-
transport.write_eof.side_effect = write_eof
81+
transport.writelines.side_effect = writelines
8182
transport.is_closing.return_value = False
8283

8384
return transport

tests/test_http_writer.py

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Tests for aiohttp/http_writer.py
22
import array
33
import asyncio
4+
import zlib
5+
from typing import Iterable
46
from unittest import mock
57

68
import pytest
@@ -23,7 +25,12 @@ def transport(buf):
2325
def write(chunk):
2426
buf.extend(chunk)
2527

28+
def writelines(chunks: Iterable[bytes]) -> None:
29+
for chunk in chunks:
30+
buf.extend(chunk)
31+
2632
transport.write.side_effect = write
33+
transport.writelines.side_effect = writelines
2734
transport.is_closing.return_value = False
2835
return transport
2936

@@ -85,21 +92,53 @@ async def test_write_payload_length(protocol, transport, loop) -> None:
8592
assert b"da" == content.split(b"\r\n\r\n", 1)[-1]
8693

8794

88-
async def test_write_payload_chunked_filter(protocol, transport, loop) -> None:
89-
write = transport.write = mock.Mock()
95+
async def test_write_large_payload_deflate_compression_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate without chunking: a large body passed to ``write_eof``.

    The first ``write`` is expected to reach ``transport.write`` only; the
    EOF payload is expected to reach ``transport.writelines`` only.  All
    captured pieces, concatenated, must decompress to the original data.
    """
    head = b"data" * 4096
    # This payload compresses to 20447 bytes
    tail = b"".join(
        bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)
    )

    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")

    await writer.write(head)
    assert transport.write.called  # type: ignore[attr-defined]
    pieces = [call.args[0] for call in transport.write.mock_calls]  # type: ignore[attr-defined]
    # Forget the calls recorded so far so the EOF phase can be checked alone.
    transport.write.reset_mock()  # type: ignore[attr-defined]
    assert not transport.writelines.called  # type: ignore[attr-defined]

    await writer.write_eof(tail)
    assert not transport.write.called  # type: ignore[attr-defined]
    assert transport.writelines.called  # type: ignore[attr-defined]
    # Only the first ``writelines`` call is inspected.
    pieces.extend(transport.writelines.mock_calls[0].args[0])  # type: ignore[attr-defined]

    assert zlib.decompress(b"".join(pieces)) == head + tail
119+
120+
121+
async def test_write_payload_chunked_filter(
122+
protocol: BaseProtocol,
123+
transport: asyncio.Transport,
124+
loop: asyncio.AbstractEventLoop,
125+
) -> None:
91126
msg = http.StreamWriter(protocol, loop)
92127
msg.enable_chunking()
93128
await msg.write(b"da")
94129
await msg.write(b"ta")
95130
await msg.write_eof()
96131

97-
content = b"".join([c[1][0] for c in list(write.mock_calls)])
132+
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
133+
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
98134
assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n")
99135

100136

101-
async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, loop):
102-
write = transport.write = mock.Mock()
137+
async def test_write_payload_chunked_filter_multiple_chunks(
138+
protocol: BaseProtocol,
139+
transport: asyncio.Transport,
140+
loop: asyncio.AbstractEventLoop,
141+
) -> None:
103142
msg = http.StreamWriter(protocol, loop)
104143
msg.enable_chunking()
105144
await msg.write(b"da")
@@ -108,14 +147,14 @@ async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport,
108147
await msg.write(b"at")
109148
await msg.write(b"a2")
110149
await msg.write_eof()
111-
content = b"".join([c[1][0] for c in list(write.mock_calls)])
150+
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
151+
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
112152
assert content.endswith(
113153
b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n"
114154
)
115155

116156

117157
async def test_write_payload_deflate_compression(protocol, transport, loop) -> None:
118-
119158
COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b"
120159
write = transport.write = mock.Mock()
121160
msg = http.StreamWriter(protocol, loop)
@@ -129,7 +168,30 @@ async def test_write_payload_deflate_compression(protocol, transport, loop) -> N
129168
assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1]
130169

131170

132-
async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop):
171+
async def test_write_payload_deflate_compression_chunked(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunked encoding with an empty ``write_eof``.

    Everything handed to ``transport.writelines`` is joined and compared
    against the exact expected wire bytes; no call may be empty.
    """
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    await writer.write(b"data")
    await writer.write_eof()

    per_call = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        per_call.append(b"".join(call.args[0]))

    # Every ``writelines`` call must have carried at least one byte.
    assert all(per_call)
    wire = b"".join(per_call)
    assert wire == b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
187+
188+
189+
async def test_write_payload_deflate_and_chunked(
190+
buf: bytearray,
191+
protocol: BaseProtocol,
192+
transport: asyncio.Transport,
193+
loop: asyncio.AbstractEventLoop,
194+
) -> None:
133195
msg = http.StreamWriter(protocol, loop)
134196
msg.enable_compression("deflate")
135197
msg.enable_chunking()
@@ -142,8 +204,71 @@ async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop)
142204
assert thing == buf
143205

144206

145-
async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop):
207+
async def test_write_payload_deflate_compression_chunked_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunked encoding with trailing data given to ``write_eof``.

    Everything handed to ``transport.writelines`` is joined and compared
    against the exact expected wire bytes; no call may be empty.
    """
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    await writer.write(b"data")
    await writer.write_eof(b"end")

    per_call = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        per_call.append(b"".join(call.args[0]))

    # Every ``writelines`` call must have carried at least one byte.
    assert all(per_call)
    wire = b"".join(per_call)
    assert wire == b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
223+
224+
225+
async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunked encoding with a large body given to ``write_eof``.

    ``transport.write`` must never be used.  After the chunk framing is
    removed from each ``writelines`` call, the remaining pieces must
    decompress to the original data.
    """
    head = b"data" * 4096
    # This payload compresses to 1111 bytes
    tail = b"".join(bytes((*range(0, i), *range(i, 0, -1))) for i in range(255))

    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    await writer.write(head)
    await writer.write_eof(tail)
    assert not transport.write.called  # type: ignore[attr-defined]

    compressed = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        # Each call is framed as (size line, *data, trailer): drop the first
        # and last elements so only the compressed data remains.
        body = list(call.args[0])[1:]
        body.pop()
        compressed.extend(body)

    assert all(compressed)
    assert zlib.decompress(b"".join(compressed)) == head + tail
249+
250+
251+
async def test_write_payload_deflate_compression_chunked_connection_lost(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """``write_eof`` must raise when the transport reports it is closing."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")

    closing = mock.patch.object(transport, "is_closing", return_value=True)
    expect_reset = pytest.raises(
        ClientConnectionResetError, match="Cannot write to closing transport"
    )
    # Enter ``pytest.raises`` first and the patch second.
    with expect_reset:
        with closing:
            await writer.write_eof(b"end")
264+
265+
266+
async def test_write_payload_bytes_memoryview(
267+
buf: bytearray,
268+
protocol: BaseProtocol,
269+
transport: asyncio.Transport,
270+
loop: asyncio.AbstractEventLoop,
271+
) -> None:
147272
msg = http.StreamWriter(protocol, loop)
148273

149274
mv = memoryview(b"abcd")

0 commit comments

Comments
 (0)