2
2
import array
3
3
import asyncio
4
4
import zlib
5
- from typing import Iterable
5
+ from typing import Generator , Iterable
6
6
from unittest import mock
7
7
8
8
import pytest
14
14
15
15
16
16
@pytest .fixture
17
- def buf ():
17
+ def enable_writelines () -> Generator [None , None , None ]:
18
+ with mock .patch ("aiohttp.http_writer.SKIP_WRITELINES" , False ):
19
+ yield
20
+
21
+
22
+ @pytest .fixture
23
+ def force_writelines_small_payloads () -> Generator [None , None , None ]:
24
+ with mock .patch ("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES" , 1 ):
25
+ yield
26
+
27
+
28
+ @pytest .fixture
29
+ def buf () -> bytearray :
18
30
return bytearray ()
19
31
20
32
@@ -117,6 +129,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
117
129
assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
118
130
119
131
132
+ @pytest .mark .usefixtures ("enable_writelines" )
133
+ async def test_write_large_payload_deflate_compression_data_in_eof_writelines (
134
+ protocol : BaseProtocol ,
135
+ transport : asyncio .Transport ,
136
+ loop : asyncio .AbstractEventLoop ,
137
+ ) -> None :
138
+ msg = http .StreamWriter (protocol , loop )
139
+ msg .enable_compression ("deflate" )
140
+
141
+ await msg .write (b"data" * 4096 )
142
+ assert transport .write .called # type: ignore[attr-defined]
143
+ chunks = [c [1 ][0 ] for c in list (transport .write .mock_calls )] # type: ignore[attr-defined]
144
+ transport .write .reset_mock () # type: ignore[attr-defined]
145
+ assert not transport .writelines .called # type: ignore[attr-defined]
146
+
147
+ # This payload compresses to 20447 bytes
148
+ payload = b"" .join (
149
+ [bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 ) for _ in range (64 )]
150
+ )
151
+ await msg .write_eof (payload )
152
+ assert not transport .write .called # type: ignore[attr-defined]
153
+ assert transport .writelines .called # type: ignore[attr-defined]
154
+ chunks .extend (transport .writelines .mock_calls [0 ][1 ][0 ]) # type: ignore[attr-defined]
155
+ content = b"" .join (chunks )
156
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
157
+
158
+
120
159
async def test_write_payload_chunked_filter (
121
160
protocol : BaseProtocol ,
122
161
transport : asyncio .Transport ,
@@ -185,6 +224,26 @@ async def test_write_payload_deflate_compression_chunked(
185
224
assert content == expected
186
225
187
226
227
+ @pytest .mark .usefixtures ("enable_writelines" )
228
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
229
+ async def test_write_payload_deflate_compression_chunked_writelines (
230
+ protocol : BaseProtocol ,
231
+ transport : asyncio .Transport ,
232
+ loop : asyncio .AbstractEventLoop ,
233
+ ) -> None :
234
+ expected = b"2\r \n x\x9c \r \n a\r \n KI,I\x04 \x00 \x04 \x00 \x01 \x9b \r \n 0\r \n \r \n "
235
+ msg = http .StreamWriter (protocol , loop )
236
+ msg .enable_compression ("deflate" )
237
+ msg .enable_chunking ()
238
+ await msg .write (b"data" )
239
+ await msg .write_eof ()
240
+
241
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
242
+ assert all (chunks )
243
+ content = b"" .join (chunks )
244
+ assert content == expected
245
+
246
+
188
247
async def test_write_payload_deflate_and_chunked (
189
248
buf : bytearray ,
190
249
protocol : BaseProtocol ,
@@ -221,6 +280,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
221
280
assert content == expected
222
281
223
282
283
+ @pytest .mark .usefixtures ("enable_writelines" )
284
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
285
+ async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines (
286
+ protocol : BaseProtocol ,
287
+ transport : asyncio .Transport ,
288
+ loop : asyncio .AbstractEventLoop ,
289
+ ) -> None :
290
+ expected = b"2\r \n x\x9c \r \n d\r \n KI,IL\xcd K\x01 \x00 \x0b @\x02 \xd2 \r \n 0\r \n \r \n "
291
+ msg = http .StreamWriter (protocol , loop )
292
+ msg .enable_compression ("deflate" )
293
+ msg .enable_chunking ()
294
+ await msg .write (b"data" )
295
+ await msg .write_eof (b"end" )
296
+
297
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
298
+ assert all (chunks )
299
+ content = b"" .join (chunks )
300
+ assert content == expected
301
+
302
+
224
303
async def test_write_large_payload_deflate_compression_chunked_data_in_eof (
225
304
protocol : BaseProtocol ,
226
305
transport : asyncio .Transport ,
@@ -247,6 +326,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
247
326
assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
248
327
249
328
329
+ @pytest .mark .usefixtures ("enable_writelines" )
330
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
331
+ async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines (
332
+ protocol : BaseProtocol ,
333
+ transport : asyncio .Transport ,
334
+ loop : asyncio .AbstractEventLoop ,
335
+ ) -> None :
336
+ msg = http .StreamWriter (protocol , loop )
337
+ msg .enable_compression ("deflate" )
338
+ msg .enable_chunking ()
339
+
340
+ await msg .write (b"data" * 4096 )
341
+ # This payload compresses to 1111 bytes
342
+ payload = b"" .join ([bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 )])
343
+ await msg .write_eof (payload )
344
+ assert not transport .write .called # type: ignore[attr-defined]
345
+
346
+ chunks = []
347
+ for write_lines_call in transport .writelines .mock_calls : # type: ignore[attr-defined]
348
+ chunked_payload = list (write_lines_call [1 ][0 ])[1 :]
349
+ chunked_payload .pop ()
350
+ chunks .extend (chunked_payload )
351
+
352
+ assert all (chunks )
353
+ content = b"" .join (chunks )
354
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
355
+
356
+
250
357
async def test_write_payload_deflate_compression_chunked_connection_lost (
251
358
protocol : BaseProtocol ,
252
359
transport : asyncio .Transport ,
0 commit comments