Skip to content

Commit 4c4f3f2

Browse files
authored
[PR #9800/f0e26eb backport][3.11] Fix incorrect typing on WebSocketResponse._handshake (#9811)
1 parent 61c7eef commit 4c4f3f2

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

aiohttp/web_ws.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import hashlib
55
import json
66
import sys
7-
from typing import Any, Final, Iterable, Optional, Tuple, cast
7+
from typing import Any, Final, Iterable, Optional, Tuple, Union, cast
88

99
import attr
1010
from multidict import CIMultiDict
@@ -99,7 +99,7 @@ def __init__(
9999
if heartbeat is not None:
100100
self._pong_heartbeat = heartbeat / 2.0
101101
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
102-
self._compress = compress
102+
self._compress: Union[bool, int] = compress
103103
self._max_msg_size = max_msg_size
104104
self._ping_task: Optional[asyncio.Task[None]] = None
105105
self._writer_limit = writer_limit
@@ -216,7 +216,7 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
216216

217217
def _handshake(
218218
self, request: BaseRequest
219-
) -> Tuple["CIMultiDict[str]", str, bool, bool]:
219+
) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]:
220220
headers = request.headers
221221
if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
222222
raise HTTPBadRequest(
@@ -234,7 +234,7 @@ def _handshake(
234234
)
235235

236236
# find common sub-protocol between client and server
237-
protocol = None
237+
protocol: Optional[str] = None
238238
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
239239
req_protocols = [
240240
str(proto.strip())
@@ -297,9 +297,9 @@ def _handshake(
297297
protocol,
298298
compress,
299299
notakeover,
300-
) # type: ignore[return-value]
300+
)
301301

302-
def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
302+
def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWriter]:
303303
self._loop = request._loop
304304

305305
headers, protocol, compress, notakeover = self._handshake(request)
@@ -321,7 +321,7 @@ def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
321321
return protocol, writer
322322

323323
def _post_start(
324-
self, request: BaseRequest, protocol: str, writer: WebSocketWriter
324+
self, request: BaseRequest, protocol: Optional[str], writer: WebSocketWriter
325325
) -> None:
326326
self._ws_protocol = protocol
327327
self._writer = writer
@@ -332,7 +332,9 @@ def _post_start(
332332
assert loop is not None
333333
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
334334
request.protocol.set_parser(
335-
WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
335+
WebSocketReader(
336+
self._reader, self._max_msg_size, compress=bool(self._compress)
337+
)
336338
)
337339
# disable HTTP keepalive for WebSocket
338340
request.protocol.keep_alive(False)
@@ -360,7 +362,7 @@ def ws_protocol(self) -> Optional[str]:
360362
return self._ws_protocol
361363

362364
@property
363-
def compress(self) -> bool:
365+
def compress(self) -> Union[int, bool]:
364366
return self._compress
365367

366368
def get_extra_info(self, name: str, default: Any = None) -> Any:

0 commit comments

Comments
 (0)