4
4
import hashlib
5
5
import json
6
6
import sys
7
- from typing import Any , Final , Iterable , Optional , Tuple , cast
7
+ from typing import Any , Final , Iterable , Optional , Tuple , Union , cast
8
8
9
9
import attr
10
10
from multidict import CIMultiDict
@@ -99,7 +99,7 @@ def __init__(
99
99
if heartbeat is not None :
100
100
self ._pong_heartbeat = heartbeat / 2.0
101
101
self ._pong_response_cb : Optional [asyncio .TimerHandle ] = None
102
- self ._compress = compress
102
+ self ._compress : Union [ bool , int ] = compress
103
103
self ._max_msg_size = max_msg_size
104
104
self ._ping_task : Optional [asyncio .Task [None ]] = None
105
105
self ._writer_limit = writer_limit
@@ -216,7 +216,7 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
216
216
217
217
def _handshake (
218
218
self , request : BaseRequest
219
- ) -> Tuple ["CIMultiDict[str]" , str , bool , bool ]:
219
+ ) -> Tuple ["CIMultiDict[str]" , Optional [ str ], int , bool ]:
220
220
headers = request .headers
221
221
if "websocket" != headers .get (hdrs .UPGRADE , "" ).lower ().strip ():
222
222
raise HTTPBadRequest (
@@ -234,7 +234,7 @@ def _handshake(
234
234
)
235
235
236
236
# find common sub-protocol between client and server
237
- protocol = None
237
+ protocol : Optional [ str ] = None
238
238
if hdrs .SEC_WEBSOCKET_PROTOCOL in headers :
239
239
req_protocols = [
240
240
str (proto .strip ())
@@ -297,9 +297,9 @@ def _handshake(
297
297
protocol ,
298
298
compress ,
299
299
notakeover ,
300
- ) # type: ignore[return-value]
300
+ )
301
301
302
- def _pre_start (self , request : BaseRequest ) -> Tuple [str , WebSocketWriter ]:
302
+ def _pre_start (self , request : BaseRequest ) -> Tuple [Optional [ str ] , WebSocketWriter ]:
303
303
self ._loop = request ._loop
304
304
305
305
headers , protocol , compress , notakeover = self ._handshake (request )
@@ -321,7 +321,7 @@ def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
321
321
return protocol , writer
322
322
323
323
def _post_start (
324
- self , request : BaseRequest , protocol : str , writer : WebSocketWriter
324
+ self , request : BaseRequest , protocol : Optional [ str ] , writer : WebSocketWriter
325
325
) -> None :
326
326
self ._ws_protocol = protocol
327
327
self ._writer = writer
@@ -332,7 +332,9 @@ def _post_start(
332
332
assert loop is not None
333
333
self ._reader = WebSocketDataQueue (request ._protocol , 2 ** 16 , loop = loop )
334
334
request .protocol .set_parser (
335
- WebSocketReader (self ._reader , self ._max_msg_size , compress = self ._compress )
335
+ WebSocketReader (
336
+ self ._reader , self ._max_msg_size , compress = bool (self ._compress )
337
+ )
336
338
)
337
339
# disable HTTP keepalive for WebSocket
338
340
request .protocol .keep_alive (False )
@@ -360,7 +362,7 @@ def ws_protocol(self) -> Optional[str]:
360
362
return self ._ws_protocol
361
363
362
364
@property
363
- def compress (self ) -> bool :
365
+ def compress (self ) -> Union [ int , bool ] :
364
366
return self ._compress
365
367
366
368
def get_extra_info (self , name : str , default : Any = None ) -> Any :
0 commit comments