Skip to content

Commit a5f5205

Browse files
Fixes for strict-bytes (#10454)
1 parent def4e4d commit a5f5205

File tree

8 files changed

+48
-21
lines changed

8 files changed

+48
-21
lines changed

.mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pretty = True
2020
show_column_numbers = True
2121
show_error_codes = True
2222
show_error_code_links = True
23+
strict_bytes = True
2324
strict_equality = True
2425
warn_incomplete_stub = True
2526
warn_redundant_casts = True

aiohttp/_websocket/reader_py.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,11 @@ def _feed_data(self, data: bytes) -> None:
309309
self.queue.feed_data(msg)
310310
elif opcode == OP_CODE_PING:
311311
self.queue.feed_data(
312-
WSMessagePing(data=payload, size=len(payload), extra="")
312+
WSMessagePing(data=bytes(payload), size=len(payload), extra="")
313313
)
314314
elif opcode == OP_CODE_PONG:
315315
self.queue.feed_data(
316-
WSMessagePong(data=payload, size=len(payload), extra="")
316+
WSMessagePong(data=bytes(payload), size=len(payload), extra="")
317317
)
318318
else:
319319
raise WebSocketError(

aiohttp/_websocket/writer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ async def send_frame(
133133
# when aiohttp is acting as a client. Servers do not use a mask.
134134
if use_mask:
135135
mask = PACK_RANDBITS(self.get_random_bits())
136-
message = bytearray(message)
137-
websocket_mask(mask, message)
138-
self.transport.write(header + mask + message)
136+
message_arr = bytearray(message)
137+
websocket_mask(mask, message_arr)
138+
self.transport.write(header + mask + message_arr)
139139
self._output_size += MASK_LEN
140140
elif msg_length > MSG_SIZE:
141141
self.transport.write(header)

aiohttp/abc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ class AbstractStreamWriter(ABC):
202202
length: Optional[int] = 0
203203

204204
@abstractmethod
205-
async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
205+
async def write(
206+
self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
207+
) -> None:
206208
"""Write chunk into stream."""
207209

208210
@abstractmethod

aiohttp/compression_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
import asyncio
2+
import sys
23
import zlib
34
from concurrent.futures import Executor
45
from typing import Optional, cast
56

7+
if sys.version_info >= (3, 12):
8+
from collections.abc import Buffer
9+
else:
10+
from typing import Union
11+
12+
Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
13+
614
try:
715
try:
816
import brotlicffi as brotli
@@ -66,10 +74,10 @@ def __init__(
6674
)
6775
self._compress_lock = asyncio.Lock()
6876

69-
def compress_sync(self, data: bytes) -> bytes:
77+
def compress_sync(self, data: Buffer) -> bytes:
7078
return self._compressor.compress(data)
7179

72-
async def compress(self, data: bytes) -> bytes:
80+
async def compress(self, data: Buffer) -> bytes:
7381
"""Compress the data and returned the compressed bytes.
7482
7583
Note that flush() must be called after the last call to compress()
@@ -111,10 +119,10 @@ def __init__(
111119
)
112120
self._decompressor = zlib.decompressobj(wbits=self._mode)
113121

114-
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
122+
def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes:
115123
return self._decompressor.decompress(data, max_length)
116124

117-
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
125+
async def decompress(self, data: Buffer, max_length: int = 0) -> bytes:
118126
"""Decompress the data and return the decompressed bytes.
119127
120128
If the data size is large than the max_sync_chunk_size, the decompression
@@ -162,7 +170,7 @@ def __init__(self) -> None:
162170
)
163171
self._obj = brotli.Decompressor()
164172

165-
def decompress_sync(self, data: bytes) -> bytes:
173+
def decompress_sync(self, data: Buffer) -> bytes:
166174
if hasattr(self._obj, "decompress"):
167175
return cast(bytes, self._obj.decompress(data))
168176
return cast(bytes, self._obj.process(data))

aiohttp/http_writer.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ class HttpVersion(NamedTuple):
4545
HttpVersion11 = HttpVersion(1, 1)
4646

4747

48-
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
48+
_T_OnChunkSent = Optional[
49+
Callable[
50+
[Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]],
51+
Awaitable[None],
52+
]
53+
]
4954
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
5055

5156

@@ -84,16 +89,23 @@ def enable_compression(
8489
) -> None:
8590
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
8691

87-
def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
92+
def _write(
93+
self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
94+
) -> None:
8895
size = len(chunk)
8996
self.buffer_size += size
9097
self.output_size += size
9198
transport = self._protocol.transport
9299
if transport is None or transport.is_closing():
93100
raise ClientConnectionResetError("Cannot write to closing transport")
94-
transport.write(chunk)
101+
transport.write(chunk) # type: ignore[arg-type]
95102

96-
def _writelines(self, chunks: Iterable[bytes]) -> None:
103+
def _writelines(
104+
self,
105+
chunks: Iterable[
106+
Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
107+
],
108+
) -> None:
97109
size = 0
98110
for chunk in chunks:
99111
size += len(chunk)
@@ -105,11 +117,11 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
105117
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
106118
transport.write(b"".join(chunks))
107119
else:
108-
transport.writelines(chunks)
120+
transport.writelines(chunks) # type: ignore[arg-type]
109121

110122
async def write(
111123
self,
112-
chunk: Union[bytes, bytearray, memoryview],
124+
chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
113125
*,
114126
drain: bool = True,
115127
LIMIT: int = 0x10000,

aiohttp/web_response.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,9 @@ async def _write_headers(self) -> None:
435435
status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
436436
await writer.write_headers(status_line, self._headers)
437437

438-
async def write(self, data: Union[bytes, bytearray, memoryview]) -> None:
438+
async def write(
439+
self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
440+
) -> None:
439441
assert isinstance(
440442
data, (bytes, bytearray, memoryview)
441443
), "data argument must be byte-ish (%r)" % type(data)
@@ -580,7 +582,7 @@ def __init__(
580582
self._zlib_executor = zlib_executor
581583

582584
@property
583-
def body(self) -> Optional[Union[bytes, Payload]]:
585+
def body(self) -> Optional[Union[bytes, bytearray, Payload]]:
584586
return self._body
585587

586588
@body.setter
@@ -654,7 +656,7 @@ async def write_eof(self, data: bytes = b"") -> None:
654656
if self._eof_sent:
655657
return
656658
if self._compressed_body is None:
657-
body: Optional[Union[bytes, Payload]] = self._body
659+
body = self._body
658660
else:
659661
body = self._compressed_body
660662
assert not data, f"data arg is not supported, got {data!r}"

aiohttp/web_ws.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,9 @@ async def receive_json(
606606
data = await self.receive_str(timeout=timeout)
607607
return loads(data)
608608

609-
async def write(self, data: bytes) -> None:
609+
async def write(
610+
self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
611+
) -> None:
610612
raise RuntimeError("Cannot call .write() for websocket")
611613

612614
def __aiter__(self) -> "WebSocketResponse":

0 commit comments

Comments
 (0)