Skip to content

Commit a81a7af

Browse files
pgjonesKriechi
authored andcommitted
Fix typing issues
These are now preset as h11 is typed.
1 parent 8f86c35 commit a81a7af

File tree

4 files changed

+25
-30
lines changed

4 files changed

+25
-30
lines changed

src/wsproto/handshake.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def receive_data(self, data: Optional[bytes]) -> None:
119119
120120
:param bytes data: Data received from the WebSocket peer.
121121
"""
122-
self._h11_connection.receive_data(data)
122+
self._h11_connection.receive_data(data or b"")
123123
while True:
124124
try:
125125
event = self._h11_connection.next_event()
@@ -141,7 +141,7 @@ def receive_data(self, data: Optional[bytes]) -> None:
141141
else:
142142
self._events.append(
143143
RejectConnection(
144-
headers=event.headers,
144+
headers=list(event.headers),
145145
status_code=event.status_code,
146146
has_body=False,
147147
)
@@ -151,7 +151,7 @@ def receive_data(self, data: Optional[bytes]) -> None:
151151
self._state = ConnectionState.REJECTING
152152
self._events.append(
153153
RejectConnection(
154-
headers=event.headers,
154+
headers=list(event.headers),
155155
status_code=event.status_code,
156156
has_body=True,
157157
)
@@ -286,36 +286,36 @@ def _accept(self, event: AcceptConnection) -> bytes:
286286
event.extensions,
287287
)
288288
self._state = ConnectionState.OPEN
289-
return self._h11_connection.send(response) # type: ignore[no-any-return]
289+
return self._h11_connection.send(response) or b""
290290

291291
def _reject(self, event: RejectConnection) -> bytes:
292292
if self.state != ConnectionState.CONNECTING:
293293
raise LocalProtocolError(
294294
"Connection cannot be rejected in state %s" % self.state
295295
)
296296

297-
headers = event.headers
297+
headers = list(event.headers)
298298
if not event.has_body:
299299
headers.append((b"content-length", b"0"))
300300
response = h11.Response(status_code=event.status_code, headers=headers)
301-
data = self._h11_connection.send(response)
301+
data = self._h11_connection.send(response) or b""
302302
self._state = ConnectionState.REJECTING
303303
if not event.has_body:
304-
data += self._h11_connection.send(h11.EndOfMessage())
304+
data += self._h11_connection.send(h11.EndOfMessage()) or b""
305305
self._state = ConnectionState.CLOSED
306-
return data # type: ignore[no-any-return]
306+
return data
307307

308308
def _send_reject_data(self, event: RejectData) -> bytes:
309309
if self.state != ConnectionState.REJECTING:
310310
raise LocalProtocolError(
311311
f"Cannot send rejection data in state {self.state}"
312312
)
313313

314-
data = self._h11_connection.send(h11.Data(data=event.data))
314+
data = self._h11_connection.send(h11.Data(data=event.data)) or b""
315315
if event.body_finished:
316-
data += self._h11_connection.send(h11.EndOfMessage())
316+
data += self._h11_connection.send(h11.EndOfMessage()) or b""
317317
self._state = ConnectionState.CLOSED
318-
return data # type: ignore[no-any-return]
318+
return data
319319

320320
# Client mode methods
321321

@@ -360,7 +360,7 @@ def _initiate_connection(self, request: Request) -> bytes:
360360
target=request.target.encode("ascii"),
361361
headers=headers + request.extra_headers,
362362
)
363-
return self._h11_connection.send(upgrade) # type: ignore[no-any-return]
363+
return self._h11_connection.send(upgrade) or b""
364364

365365
def _establish_client_connection(
366366
self, event: h11.InformationalResponse
@@ -387,7 +387,7 @@ def _establish_client_connection(
387387
accept = value
388388
continue # Skip appending to headers
389389
elif name == b"sec-websocket-protocol":
390-
subprotocol = value
390+
subprotocol = value.decode("ascii")
391391
continue # Skip appending to headers
392392
elif name == b"upgrade":
393393
upgrade = value
@@ -408,7 +408,6 @@ def _establish_client_connection(
408408
if accept != accept_token:
409409
raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
410410
if subprotocol is not None:
411-
subprotocol = subprotocol.decode("ascii")
412411
if subprotocol not in self._initiating_request.subprotocols:
413412
raise RemoteProtocolError(
414413
f"unrecognized subprotocol {subprotocol}",

src/wsproto/utilities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import base64
88
import hashlib
99
import os
10-
from typing import Dict, List, Optional
10+
from typing import Dict, List, Optional, Union
11+
12+
from h11._headers import Headers as H11Headers
1113

1214
from .events import Event
1315
from .typing import Headers
@@ -51,7 +53,7 @@ def __init__(self, message: str, event_hint: Optional[Event] = None) -> None:
5153

5254

5355
# Some convenience utilities for working with HTTP headers
54-
def normed_header_dict(h11_headers: Headers) -> Dict[bytes, bytes]:
56+
def normed_header_dict(h11_headers: Union[Headers, H11Headers]) -> Dict[bytes, bytes]:
5557
# This mangles Set-Cookie headers. But it happens that we don't care about
5658
# any of those, so it's OK. For every other HTTP header, if there are
5759
# multiple instances then you're allowed to join them together with

test/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import cast, List, Optional
22

33
import h11
44
import pytest
@@ -27,7 +27,7 @@ def _make_connection_request(request: Request) -> h11.Request:
2727
client = WSConnection(CLIENT)
2828
server = h11.Connection(h11.SERVER)
2929
server.receive_data(client.send(request))
30-
return server.next_event()
30+
return cast(h11.Request, server.next_event())
3131

3232

3333
def test_connection_request() -> None:
@@ -114,7 +114,7 @@ def test_connection_send_state() -> None:
114114
)
115115
)
116116
)
117-
headers = normed_header_dict(server.next_event().headers)
117+
headers = normed_header_dict(cast(h11.Request, server.next_event()).headers)
118118
response = h11.InformationalResponse(
119119
status_code=101,
120120
headers=[
@@ -158,7 +158,7 @@ def _make_handshake(
158158
)
159159
)
160160
)
161-
request = server.next_event()
161+
request = cast(h11.Request, server.next_event())
162162
if auto_accept_key:
163163
full_request_headers = normed_header_dict(request.headers)
164164
response_headers.append(

test/test_server.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@
55

66
from wsproto import WSConnection
77
from wsproto.connection import SERVER
8-
from wsproto.events import (
9-
AcceptConnection,
10-
Event,
11-
RejectConnection,
12-
RejectData,
13-
Request,
14-
)
8+
from wsproto.events import AcceptConnection, RejectConnection, RejectData, Request
159
from wsproto.extensions import Extension
1610
from wsproto.typing import Headers
1711
from wsproto.utilities import (
@@ -199,7 +193,7 @@ def _make_handshake(
199193
)
200194
)
201195
event = client.next_event()
202-
return event, nonce
196+
return cast(h11.InformationalResponse, event), nonce
203197

204198

205199
def test_handshake() -> None:
@@ -292,7 +286,7 @@ def test_protocol_error() -> None:
292286

293287
def _make_handshake_rejection(
294288
status_code: int, body: Optional[bytes] = None
295-
) -> List[Event]:
289+
) -> List[h11.Event]:
296290
client = h11.Connection(h11.CLIENT)
297291
server = WSConnection(SERVER)
298292
nonce = generate_nonce()
@@ -327,7 +321,7 @@ def _make_handshake_rejection(
327321
events = []
328322
while True:
329323
event = client.next_event()
330-
events.append(event)
324+
events.append(cast(h11.Event, event))
331325
if isinstance(event, h11.EndOfMessage):
332326
return events
333327

0 commit comments

Comments
 (0)