Skip to content

Commit b006e90

Browse files
authored
Autobahn WebSocket Protocol - Improve typing (#1838)
* Autobahn WebSocket Procol - Improve typing * Add AI disclosure * Increase conformity to style guide
1 parent 5e4f7e0 commit b006e90

File tree

6 files changed

+270
-245
lines changed

6 files changed

+270
-245
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- [x] I did **not** use any AI-assistance tools to help create this pull request.
2+
- [ ] I **did** use AI-assistance tools to *help* create this pull request.
3+
- [x] I have read, understood and followed the projects' [AI Policy](https://github.com/crossbario/autobahn-python/blob/main/AI_POLICY.md) when creating code, documentation etc. for this pull request.
4+
5+
Submitted by: @bblommers
6+
Date: 2026-01-07
7+
Related issue(s): #1839
8+
Branch: bblommers:websocket-server-typing

justfile

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,17 +702,15 @@ check-typing venv="": (install venv)
702702
--ignore unresolved-attribute \
703703
--ignore unresolved-reference \
704704
--ignore possibly-missing-attribute \
705-
--ignore possibly-missing-import \
706705
--ignore call-non-callable \
707706
--ignore invalid-assignment \
708707
--ignore invalid-argument-type \
709-
--ignore invalid-return-type \
710708
--ignore invalid-method-override \
711709
--ignore invalid-type-form \
712710
--ignore unsupported-operator \
713711
--ignore too-many-positional-arguments \
714712
--ignore unknown-argument \
715-
--ignore non-subscriptable \
713+
--ignore not-subscriptable \
716714
--ignore not-iterable \
717715
--ignore no-matching-overload \
718716
--ignore conflicting-declarations \

src/autobahn/twisted/websocket.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
#
2525
###############################################################################
2626

27+
from __future__ import annotations
28+
2729
from base64 import b64decode, b64encode
28-
from typing import Optional
30+
from typing import Any
2931

3032
import txaio
3133
from zope.interface import implementer
@@ -79,14 +81,14 @@
7981
)
8082

8183

82-
def create_client_agent(reactor):
84+
def create_client_agent(reactor) -> "_TwistedWebSocketClientAgent":
8385
"""
8486
:returns: an instance implementing IWebSocketClientAgent
8587
"""
8688
return _TwistedWebSocketClientAgent(reactor)
8789

8890

89-
def check_transport_config(transport_config):
91+
def check_transport_config(transport_config: str) -> None:
9092
"""
9193
raises a ValueError if `transport_config` is invalid
9294
"""
@@ -107,7 +109,7 @@ def check_transport_config(transport_config):
107109
return None
108110

109111

110-
def check_client_options(options):
112+
def check_client_options(options: dict[str, Any]) -> None:
111113
"""
112114
raises a ValueError if `options` is invalid
113115
"""
@@ -261,10 +263,10 @@ class WebSocketAdapterProtocol(twisted.internet.protocol.Protocol):
261263

262264
log = txaio.make_logger()
263265

264-
peer: Optional[str] = None
265-
is_server: Optional[bool] = None
266+
peer: str | None = None
267+
is_server: bool | None = None
266268

267-
def connectionMade(self):
269+
def connectionMade(self) -> None:
268270
# Twisted networking framework entry point, called by Twisted
269271
# when the connection is established (either a client or a server)
270272

@@ -296,7 +298,7 @@ def connectionMade(self):
296298
peer=hlval(self.peer),
297299
)
298300

299-
def connectionLost(self, reason: Failure = connectionDone):
301+
def connectionLost(self, reason: Failure = connectionDone) -> None:
300302
# Twisted networking framework entry point, called by Twisted
301303
# when the connection is lost (either a client or a server)
302304

@@ -352,7 +354,7 @@ def connectionLost(self, reason: Failure = connectionDone):
352354
reason=reason,
353355
)
354356

355-
def dataReceived(self, data: bytes):
357+
def dataReceived(self, data: bytes) -> None:
356358
self.log.debug(
357359
'{func} received {data_len} bytes for peer="{peer}"',
358360
func=hltype(self.dataReceived),
@@ -363,14 +365,14 @@ def dataReceived(self, data: bytes):
363365
# bytes received from Twisted, forward to the networking framework independent code for websocket
364366
self._dataReceived(data)
365367

366-
def _closeConnection(self, abort=False):
368+
def _closeConnection(self, abort: bool=False) -> None:
367369
if abort and hasattr(self.transport, "abortConnection"):
368370
self.transport.abortConnection()
369371
else:
370372
# e.g. ProcessProtocol lacks abortConnection()
371373
self.transport.loseConnection()
372374

373-
def _onOpen(self):
375+
def _onOpen(self) -> None:
374376
if self._transport_details.is_secure:
375377
# now that the TLS opening handshake is complete, the actual TLS channel ID
376378
# will be available. make sure to set it!
@@ -383,37 +385,37 @@ def _onOpen(self):
383385

384386
self.onOpen()
385387

386-
def _onMessageBegin(self, isBinary):
388+
def _onMessageBegin(self, isBinary: bool) -> None:
387389
self.onMessageBegin(isBinary)
388390

389-
def _onMessageFrameBegin(self, length):
391+
def _onMessageFrameBegin(self, length: int) -> None:
390392
self.onMessageFrameBegin(length)
391393

392-
def _onMessageFrameData(self, payload):
394+
def _onMessageFrameData(self, payload) -> None:
393395
self.onMessageFrameData(payload)
394396

395-
def _onMessageFrameEnd(self):
397+
def _onMessageFrameEnd(self) -> None:
396398
self.onMessageFrameEnd()
397399

398-
def _onMessageFrame(self, payload):
400+
def _onMessageFrame(self, payload) -> None:
399401
self.onMessageFrame(payload)
400402

401-
def _onMessageEnd(self):
403+
def _onMessageEnd(self) -> None:
402404
self.onMessageEnd()
403405

404-
def _onMessage(self, payload, isBinary):
406+
def _onMessage(self, payload, isBinary: bool) -> None:
405407
self.onMessage(payload, isBinary)
406408

407-
def _onPing(self, payload):
409+
def _onPing(self, payload) -> None:
408410
self.onPing(payload)
409411

410-
def _onPong(self, payload):
412+
def _onPong(self, payload) -> None:
411413
self.onPong(payload)
412414

413-
def _onClose(self, wasClean, code, reason):
415+
def _onClose(self, wasClean: bool, code, reason) -> None:
414416
self.onClose(wasClean, code, reason)
415417

416-
def registerProducer(self, producer, streaming):
418+
def registerProducer(self, producer, streaming) -> None:
417419
"""
418420
Register a Twisted producer with this protocol.
419421
@@ -424,7 +426,7 @@ def registerProducer(self, producer, streaming):
424426
"""
425427
self.transport.registerProducer(producer, streaming)
426428

427-
def unregisterProducer(self):
429+
def unregisterProducer(self) -> None:
428430
"""
429431
Unregister Twisted producer with this protocol.
430432
"""
@@ -608,10 +610,10 @@ def onConnect(self, requestOrResponse):
608610
# should not arrive here
609611
raise Exception("logic error")
610612

611-
def onOpen(self):
613+
def onOpen(self) -> None:
612614
self._proto.connectionMade()
613615

614-
def onMessage(self, payload, isBinary):
616+
def onMessage(self, payload: bytes, isBinary: bool) -> None:
615617
if isBinary != self._binaryMode:
616618
self._fail_connection(
617619
protocol.WebSocketProtocol.CLOSE_STATUS_CODE_UNSUPPORTED_DATA,
@@ -632,7 +634,7 @@ def onMessage(self, payload, isBinary):
632634
def onClose(self, wasClean, code, reason):
633635
self._proto.connectionLost(None)
634636

635-
def write(self, data):
637+
def write(self, data: bytes) -> None:
636638
# part of ITransport
637639
assert type(data) == bytes
638640
if self._binaryMode:
@@ -641,12 +643,12 @@ def write(self, data):
641643
data = b64encode(data)
642644
self.sendMessage(data, isBinary=False)
643645

644-
def writeSequence(self, data):
646+
def writeSequence(self, data: bytes) -> None:
645647
# part of ITransport
646648
for d in data:
647649
self.write(d)
648650

649-
def loseConnection(self):
651+
def loseConnection(self) -> None:
650652
# part of ITransport
651653
self.sendClose()
652654

src/autobahn/wamp/message.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
#
2525
###############################################################################
2626

27+
from __future__ import annotations
28+
2729
import binascii
2830
import re
2931
import textwrap
3032
from pprint import pformat
31-
from typing import Any, Dict, Optional
33+
from typing import Any, Literal, overload
3234

3335
import autobahn
3436
from autobahn.util import hlval
@@ -243,7 +245,7 @@ def b2a(data, max_len=40):
243245
return s
244246

245247

246-
def identify_realm_name_category(value: Any) -> Optional[str]:
248+
def identify_realm_name_category(value: Any) -> str | None:
247249
"""
248250
Identify the real name category of the given value:
249251
@@ -272,14 +274,38 @@ def identify_realm_name_category(value: Any) -> Optional[str]:
272274
return None
273275

274276

277+
@overload
278+
def check_or_raise_uri(
279+
value: Any,
280+
message: str,
281+
strict: bool,
282+
allow_empty_components: bool,
283+
allow_last_empty: bool,
284+
allow_none: Literal[True],
285+
) -> str | None:
286+
pass
287+
288+
289+
@overload
275290
def check_or_raise_uri(
276291
value: Any,
277292
message: str = "WAMP message invalid",
278293
strict: bool = False,
279294
allow_empty_components: bool = False,
280295
allow_last_empty: bool = False,
281-
allow_none: bool = False,
296+
allow_none: Literal[False] = False,
282297
) -> str:
298+
pass
299+
300+
301+
def check_or_raise_uri(
302+
value: Any,
303+
message: str = "WAMP message invalid",
304+
strict: bool = False,
305+
allow_empty_components: bool = False,
306+
allow_last_empty: bool = False,
307+
allow_none: bool = False,
308+
) -> str | None:
283309
"""
284310
Check a value for being a valid WAMP URI.
285311
@@ -408,7 +434,7 @@ def check_or_raise_id(value: Any, message: str = "WAMP message invalid") -> int:
408434

409435
def check_or_raise_extra(
410436
value: Any, message: str = "WAMP message invalid"
411-
) -> Dict[str, Any]:
437+
) -> dict[str, Any]:
412438
"""
413439
Check a value for being a valid WAMP extra dictionary.
414440

0 commit comments

Comments
 (0)