1
1
# Tests for aiohttp/http_writer.py
2
2
import array
3
3
import asyncio
4
+ import zlib
5
+ from typing import Iterable
4
6
from unittest import mock
5
7
6
8
import pytest
@@ -23,7 +25,12 @@ def transport(buf):
23
25
def write (chunk ):
24
26
buf .extend (chunk )
25
27
28
+ def writelines (chunks : Iterable [bytes ]) -> None :
29
+ for chunk in chunks :
30
+ buf .extend (chunk )
31
+
26
32
transport .write .side_effect = write
33
+ transport .writelines .side_effect = writelines
27
34
transport .is_closing .return_value = False
28
35
return transport
29
36
@@ -85,21 +92,53 @@ async def test_write_payload_length(protocol, transport, loop) -> None:
85
92
assert b"da" == content .split (b"\r \n \r \n " , 1 )[- 1 ]
86
93
87
94
88
- async def test_write_payload_chunked_filter (protocol , transport , loop ) -> None :
89
- write = transport .write = mock .Mock ()
95
+ async def test_write_large_payload_deflate_compression_data_in_eof (
96
+ protocol : BaseProtocol ,
97
+ transport : asyncio .Transport ,
98
+ loop : asyncio .AbstractEventLoop ,
99
+ ) -> None :
100
+ msg = http .StreamWriter (protocol , loop )
101
+ msg .enable_compression ("deflate" )
102
+
103
+ await msg .write (b"data" * 4096 )
104
+ assert transport .write .called # type: ignore[attr-defined]
105
+ chunks = [c [1 ][0 ] for c in list (transport .write .mock_calls )] # type: ignore[attr-defined]
106
+ transport .write .reset_mock () # type: ignore[attr-defined]
107
+ assert not transport .writelines .called # type: ignore[attr-defined]
90
108
109
+ # This payload compresses to 20447 bytes
110
+ payload = b"" .join (
111
+ [bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 ) for _ in range (64 )]
112
+ )
113
+ await msg .write_eof (payload )
114
+ assert not transport .write .called # type: ignore[attr-defined]
115
+ assert transport .writelines .called # type: ignore[attr-defined]
116
+ chunks .extend (transport .writelines .mock_calls [0 ][1 ][0 ]) # type: ignore[attr-defined]
117
+ content = b"" .join (chunks )
118
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
119
+
120
+
121
+ async def test_write_payload_chunked_filter (
122
+ protocol : BaseProtocol ,
123
+ transport : asyncio .Transport ,
124
+ loop : asyncio .AbstractEventLoop ,
125
+ ) -> None :
91
126
msg = http .StreamWriter (protocol , loop )
92
127
msg .enable_chunking ()
93
128
await msg .write (b"da" )
94
129
await msg .write (b"ta" )
95
130
await msg .write_eof ()
96
131
97
- content = b"" .join ([c [1 ][0 ] for c in list (write .mock_calls )])
132
+ content = b"" .join ([b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )]) # type: ignore[attr-defined]
133
+ content += b"" .join ([c [1 ][0 ] for c in list (transport .write .mock_calls )]) # type: ignore[attr-defined]
98
134
assert content .endswith (b"2\r \n da\r \n 2\r \n ta\r \n 0\r \n \r \n " )
99
135
100
136
101
- async def test_write_payload_chunked_filter_mutiple_chunks (protocol , transport , loop ):
102
- write = transport .write = mock .Mock ()
137
+ async def test_write_payload_chunked_filter_multiple_chunks (
138
+ protocol : BaseProtocol ,
139
+ transport : asyncio .Transport ,
140
+ loop : asyncio .AbstractEventLoop ,
141
+ ) -> None :
103
142
msg = http .StreamWriter (protocol , loop )
104
143
msg .enable_chunking ()
105
144
await msg .write (b"da" )
@@ -108,14 +147,14 @@ async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport,
108
147
await msg .write (b"at" )
109
148
await msg .write (b"a2" )
110
149
await msg .write_eof ()
111
- content = b"" .join ([c [1 ][0 ] for c in list (write .mock_calls )])
150
+ content = b"" .join ([b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )]) # type: ignore[attr-defined]
151
+ content += b"" .join ([c [1 ][0 ] for c in list (transport .write .mock_calls )]) # type: ignore[attr-defined]
112
152
assert content .endswith (
113
153
b"2\r \n da\r \n 2\r \n ta\r \n 2\r \n 1d\r \n 2\r \n at\r \n 2\r \n a2\r \n 0\r \n \r \n "
114
154
)
115
155
116
156
117
157
async def test_write_payload_deflate_compression (protocol , transport , loop ) -> None :
118
-
119
158
COMPRESSED = b"x\x9c KI,I\x04 \x00 \x04 \x00 \x01 \x9b "
120
159
write = transport .write = mock .Mock ()
121
160
msg = http .StreamWriter (protocol , loop )
@@ -129,7 +168,30 @@ async def test_write_payload_deflate_compression(protocol, transport, loop) -> N
129
168
assert COMPRESSED == content .split (b"\r \n \r \n " , 1 )[- 1 ]
130
169
131
170
132
- async def test_write_payload_deflate_and_chunked (buf , protocol , transport , loop ):
171
+ async def test_write_payload_deflate_compression_chunked (
172
+ protocol : BaseProtocol ,
173
+ transport : asyncio .Transport ,
174
+ loop : asyncio .AbstractEventLoop ,
175
+ ) -> None :
176
+ expected = b"2\r \n x\x9c \r \n a\r \n KI,I\x04 \x00 \x04 \x00 \x01 \x9b \r \n 0\r \n \r \n "
177
+ msg = http .StreamWriter (protocol , loop )
178
+ msg .enable_compression ("deflate" )
179
+ msg .enable_chunking ()
180
+ await msg .write (b"data" )
181
+ await msg .write_eof ()
182
+
183
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
184
+ assert all (chunks )
185
+ content = b"" .join (chunks )
186
+ assert content == expected
187
+
188
+
189
+ async def test_write_payload_deflate_and_chunked (
190
+ buf : bytearray ,
191
+ protocol : BaseProtocol ,
192
+ transport : asyncio .Transport ,
193
+ loop : asyncio .AbstractEventLoop ,
194
+ ) -> None :
133
195
msg = http .StreamWriter (protocol , loop )
134
196
msg .enable_compression ("deflate" )
135
197
msg .enable_chunking ()
@@ -142,8 +204,71 @@ async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop)
142
204
assert thing == buf
143
205
144
206
145
- async def test_write_payload_bytes_memoryview (buf , protocol , transport , loop ):
207
+ async def test_write_payload_deflate_compression_chunked_data_in_eof (
208
+ protocol : BaseProtocol ,
209
+ transport : asyncio .Transport ,
210
+ loop : asyncio .AbstractEventLoop ,
211
+ ) -> None :
212
+ expected = b"2\r \n x\x9c \r \n d\r \n KI,IL\xcd K\x01 \x00 \x0b @\x02 \xd2 \r \n 0\r \n \r \n "
213
+ msg = http .StreamWriter (protocol , loop )
214
+ msg .enable_compression ("deflate" )
215
+ msg .enable_chunking ()
216
+ await msg .write (b"data" )
217
+ await msg .write_eof (b"end" )
218
+
219
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
220
+ assert all (chunks )
221
+ content = b"" .join (chunks )
222
+ assert content == expected
223
+
224
+
225
+ async def test_write_large_payload_deflate_compression_chunked_data_in_eof (
226
+ protocol : BaseProtocol ,
227
+ transport : asyncio .Transport ,
228
+ loop : asyncio .AbstractEventLoop ,
229
+ ) -> None :
230
+ msg = http .StreamWriter (protocol , loop )
231
+ msg .enable_compression ("deflate" )
232
+ msg .enable_chunking ()
233
+
234
+ await msg .write (b"data" * 4096 )
235
+ # This payload compresses to 1111 bytes
236
+ payload = b"" .join ([bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 )])
237
+ await msg .write_eof (payload )
238
+ assert not transport .write .called # type: ignore[attr-defined]
146
239
240
+ chunks = []
241
+ for write_lines_call in transport .writelines .mock_calls : # type: ignore[attr-defined]
242
+ chunked_payload = list (write_lines_call [1 ][0 ])[1 :]
243
+ chunked_payload .pop ()
244
+ chunks .extend (chunked_payload )
245
+
246
+ assert all (chunks )
247
+ content = b"" .join (chunks )
248
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
249
+
250
+
251
+ async def test_write_payload_deflate_compression_chunked_connection_lost (
252
+ protocol : BaseProtocol ,
253
+ transport : asyncio .Transport ,
254
+ loop : asyncio .AbstractEventLoop ,
255
+ ) -> None :
256
+ msg = http .StreamWriter (protocol , loop )
257
+ msg .enable_compression ("deflate" )
258
+ msg .enable_chunking ()
259
+ await msg .write (b"data" )
260
+ with pytest .raises (
261
+ ClientConnectionResetError , match = "Cannot write to closing transport"
262
+ ), mock .patch .object (transport , "is_closing" , return_value = True ):
263
+ await msg .write_eof (b"end" )
264
+
265
+
266
+ async def test_write_payload_bytes_memoryview (
267
+ buf : bytearray ,
268
+ protocol : BaseProtocol ,
269
+ transport : asyncio .Transport ,
270
+ loop : asyncio .AbstractEventLoop ,
271
+ ) -> None :
147
272
msg = http .StreamWriter (protocol , loop )
148
273
149
274
mv = memoryview (b"abcd" )
0 commit comments