Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .audit/bblommers_websocket-server-typing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- [x] I did **not** use any AI-assistance tools to help create this pull request.
- [ ] I **did** use AI-assistance tools to *help* create this pull request.
- [x] I have read, understood and followed the projects' [AI Policy](https://github.com/crossbario/autobahn-python/blob/main/AI_POLICY.md) when creating code, documentation etc. for this pull request.

Submitted by: @bblommers
Date: 2026-01-07
Related issue(s): #1839
Branch: bblommers:websocket-server-typing
4 changes: 1 addition & 3 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -702,17 +702,15 @@ check-typing venv="": (install venv)
--ignore unresolved-attribute \
--ignore unresolved-reference \
--ignore possibly-missing-attribute \
--ignore possibly-missing-import \
--ignore call-non-callable \
--ignore invalid-assignment \
--ignore invalid-argument-type \
--ignore invalid-return-type \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the changes in autobahn/wamp, this check now passes

--ignore invalid-method-override \
--ignore invalid-type-form \
--ignore unsupported-operator \
--ignore too-many-positional-arguments \
--ignore unknown-argument \
--ignore non-subscriptable \
--ignore not-subscriptable \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor typo to get the CI to pass

--ignore not-iterable \
--ignore no-matching-overload \
--ignore conflicting-declarations \
Expand Down
58 changes: 30 additions & 28 deletions src/autobahn/twisted/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
#
###############################################################################

from __future__ import annotations

from base64 import b64decode, b64encode
from typing import Optional
from typing import Any

import txaio
from zope.interface import implementer
Expand Down Expand Up @@ -79,14 +81,14 @@
)


def create_client_agent(reactor):
def create_client_agent(reactor) -> "_TwistedWebSocketClientAgent":
"""
:returns: an instance implementing IWebSocketClientAgent
"""
return _TwistedWebSocketClientAgent(reactor)


def check_transport_config(transport_config):
def check_transport_config(transport_config: str) -> None:
"""
raises a ValueError if `transport_config` is invalid
"""
Expand All @@ -107,7 +109,7 @@ def check_transport_config(transport_config):
return None


def check_client_options(options):
def check_client_options(options: dict[str, Any]) -> None:
"""
raises a ValueError if `options` is invalid
"""
Expand Down Expand Up @@ -261,10 +263,10 @@ class WebSocketAdapterProtocol(twisted.internet.protocol.Protocol):

log = txaio.make_logger()

peer: Optional[str] = None
is_server: Optional[bool] = None
peer: str | None = None
is_server: bool | None = None

def connectionMade(self):
def connectionMade(self) -> None:
# Twisted networking framework entry point, called by Twisted
# when the connection is established (either a client or a server)

Expand Down Expand Up @@ -296,7 +298,7 @@ def connectionMade(self):
peer=hlval(self.peer),
)

def connectionLost(self, reason: Failure = connectionDone):
def connectionLost(self, reason: Failure = connectionDone) -> None:
# Twisted networking framework entry point, called by Twisted
# when the connection is lost (either a client or a server)

Expand Down Expand Up @@ -352,7 +354,7 @@ def connectionLost(self, reason: Failure = connectionDone):
reason=reason,
)

def dataReceived(self, data: bytes):
def dataReceived(self, data: bytes) -> None:
self.log.debug(
'{func} received {data_len} bytes for peer="{peer}"',
func=hltype(self.dataReceived),
Expand All @@ -363,14 +365,14 @@ def dataReceived(self, data: bytes):
# bytes received from Twisted, forward to the networking framework independent code for websocket
self._dataReceived(data)

def _closeConnection(self, abort=False):
def _closeConnection(self, abort: bool=False) -> None:
if abort and hasattr(self.transport, "abortConnection"):
self.transport.abortConnection()
else:
# e.g. ProcessProtocol lacks abortConnection()
self.transport.loseConnection()

def _onOpen(self):
def _onOpen(self) -> None:
if self._transport_details.is_secure:
# now that the TLS opening handshake is complete, the actual TLS channel ID
# will be available. make sure to set it!
Expand All @@ -383,37 +385,37 @@ def _onOpen(self):

self.onOpen()

def _onMessageBegin(self, isBinary):
def _onMessageBegin(self, isBinary: bool) -> None:
self.onMessageBegin(isBinary)

def _onMessageFrameBegin(self, length):
def _onMessageFrameBegin(self, length: int) -> None:
self.onMessageFrameBegin(length)

def _onMessageFrameData(self, payload):
def _onMessageFrameData(self, payload) -> None:
self.onMessageFrameData(payload)

def _onMessageFrameEnd(self):
def _onMessageFrameEnd(self) -> None:
self.onMessageFrameEnd()

def _onMessageFrame(self, payload):
def _onMessageFrame(self, payload) -> None:
self.onMessageFrame(payload)

def _onMessageEnd(self):
def _onMessageEnd(self) -> None:
self.onMessageEnd()

def _onMessage(self, payload, isBinary):
def _onMessage(self, payload, isBinary: bool) -> None:
self.onMessage(payload, isBinary)

def _onPing(self, payload):
def _onPing(self, payload) -> None:
self.onPing(payload)

def _onPong(self, payload):
def _onPong(self, payload) -> None:
self.onPong(payload)

def _onClose(self, wasClean, code, reason):
def _onClose(self, wasClean: bool, code, reason) -> None:
self.onClose(wasClean, code, reason)

def registerProducer(self, producer, streaming):
def registerProducer(self, producer, streaming) -> None:
"""
Register a Twisted producer with this protocol.

Expand All @@ -424,7 +426,7 @@ def registerProducer(self, producer, streaming):
"""
self.transport.registerProducer(producer, streaming)

def unregisterProducer(self):
def unregisterProducer(self) -> None:
"""
Unregister Twisted producer with this protocol.
"""
Expand Down Expand Up @@ -608,10 +610,10 @@ def onConnect(self, requestOrResponse):
# should not arrive here
raise Exception("logic error")

def onOpen(self):
def onOpen(self) -> None:
self._proto.connectionMade()

def onMessage(self, payload, isBinary):
def onMessage(self, payload: bytes, isBinary: bool) -> None:
if isBinary != self._binaryMode:
self._fail_connection(
protocol.WebSocketProtocol.CLOSE_STATUS_CODE_UNSUPPORTED_DATA,
Expand All @@ -632,7 +634,7 @@ def onMessage(self, payload, isBinary):
def onClose(self, wasClean, code, reason):
self._proto.connectionLost(None)

def write(self, data):
def write(self, data: bytes) -> None:
# part of ITransport
assert type(data) == bytes
if self._binaryMode:
Expand All @@ -641,12 +643,12 @@ def write(self, data):
data = b64encode(data)
self.sendMessage(data, isBinary=False)

def writeSequence(self, data):
def writeSequence(self, data: bytes) -> None:
# part of ITransport
for d in data:
self.write(d)

def loseConnection(self):
def loseConnection(self) -> None:
# part of ITransport
self.sendClose()

Expand Down
34 changes: 30 additions & 4 deletions src/autobahn/wamp/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
#
###############################################################################

from __future__ import annotations

import binascii
import re
import textwrap
from pprint import pformat
from typing import Any, Dict, Optional
from typing import Any, Literal, overload

import autobahn
from autobahn.util import hlval
Expand Down Expand Up @@ -243,7 +245,7 @@ def b2a(data, max_len=40):
return s


def identify_realm_name_category(value: Any) -> Optional[str]:
def identify_realm_name_category(value: Any) -> str | None:
"""
Identify the real name category of the given value:

Expand Down Expand Up @@ -272,14 +274,38 @@ def identify_realm_name_category(value: Any) -> Optional[str]:
return None


@overload
def check_or_raise_uri(
value: Any,
message: str,
strict: bool,
allow_empty_components: bool,
allow_last_empty: bool,
allow_none: Literal[True],
) -> str | None:
pass


@overload
def check_or_raise_uri(
value: Any,
message: str = "WAMP message invalid",
strict: bool = False,
allow_empty_components: bool = False,
allow_last_empty: bool = False,
allow_none: bool = False,
allow_none: Literal[False] = False,
) -> str:
pass


def check_or_raise_uri(
value: Any,
message: str = "WAMP message invalid",
strict: bool = False,
allow_empty_components: bool = False,
allow_last_empty: bool = False,
allow_none: bool = False,
) -> str | None:
"""
Check a value for being a valid WAMP URI.

Expand Down Expand Up @@ -408,7 +434,7 @@ def check_or_raise_id(value: Any, message: str = "WAMP message invalid") -> int:

def check_or_raise_extra(
value: Any, message: str = "WAMP message invalid"
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Check a value for being a valid WAMP extra dictionary.

Expand Down
Loading