Skip to content

Commit af2e62a

Browse files
committed
Improve Django Channels type stubs with better annotations and parameter defaults
1 parent 5668f41 commit af2e62a

File tree

19 files changed

+128
-86
lines changed

19 files changed

+128
-86
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1-
channels.auth.UserLazyObject
1+
# channels.auth.UserLazyObject.DoesNotExist is not present at runtime
2+
# channels.auth.UserLazyObject.MultipleObjectsReturned is not present at runtime
3+
# channels.auth.UserLazyObject@AnnotatedWith is not present at runtime
24
channels.auth.UserLazyObject.*
5+
6+
# database_sync_to_async is implemented as a class instance but stubbed as a function
7+
# for better type inference when used as decorator/function
38
channels.db.database_sync_to_async

stubs/channels/METADATA.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version = "4.*"
1+
version = "4.2.*"
22
upstream_repository = "https://github.com/django/channels"
33
requires = ["django-stubs>=4.2,<5.3", "asgiref"]
44

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
__version__: str
2-
DEFAULT_CHANNEL_LAYER: str
1+
from typing import Final
2+
3+
__version__: Final[str]
4+
DEFAULT_CHANNEL_LAYER: Final[str]

stubs/channels/channels/apps.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from django.apps import AppConfig
22

33
class ChannelsConfig(AppConfig):
4-
name: str = ...
5-
verbose_name: str = ...
4+
name: str = "channels"
5+
verbose_name: str = "Channels"

stubs/channels/channels/auth.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ from .consumer import _ChannelScope
99
from .utils import _ChannelApplication
1010

1111
async def get_user(scope: _ChannelScope) -> AbstractBaseUser | AnonymousUser: ...
12-
async def login(scope: _ChannelScope, user: AbstractBaseUser, backend: BaseBackend | None = ...) -> None: ...
12+
async def login(scope: _ChannelScope, user: AbstractBaseUser, backend: BaseBackend | None = None) -> None: ...
1313
async def logout(scope: _ChannelScope) -> None: ...
1414

15+
# Inherits AbstractBaseUser to improve autocomplete and show this is a lazy proxy for a user.
16+
# At runtime, it's just a LazyObject that wraps the actual user instance.
1517
class UserLazyObject(AbstractBaseUser, LazyObject): ...
1618

1719
class AuthMiddleware(BaseMiddleware):
Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Awaitable
2-
from typing import Any, ClassVar, Protocol, type_check_only
2+
from typing import Any, ClassVar, Protocol, TypedDict, type_check_only
33

44
from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, Scope, WebSocketScope
55
from channels.auth import UserLazyObject
@@ -8,34 +8,49 @@ from channels.layers import BaseChannelLayer
88
from django.contrib.sessions.backends.base import SessionBase
99
from django.utils.functional import LazyObject
1010

11+
# _LazySession is a LazyObject that wraps a SessionBase instance.
12+
# We subclass both for type checking purposes to expose SessionBase attributes,
13+
# and suppress mypy's "misc" error with `# type: ignore[misc]`.
1114
@type_check_only
1215
class _LazySession(SessionBase, LazyObject): # type: ignore[misc]
1316
_wrapped: SessionBase
1417

15-
# Base ASGI Scope definition
18+
@type_check_only
19+
class _URLRoute(TypedDict):
20+
# Values extracted from Django's URLPattern matching,
21+
# passed through ASGI scope routing.
22+
# `args` and `kwargs` are the result of pattern matching against the URL path.
23+
args: tuple[Any, ...]
24+
kwargs: dict[str, Any]
25+
26+
# Channel Scope definition
1627
@type_check_only
1728
class _ChannelScope(WebSocketScope, total=False):
1829
# Channels specific
1930
channel: str
20-
url_route: dict[str, Any]
31+
url_route: _URLRoute
2132
path_remaining: str
2233

2334
# Auth specific
2435
cookies: dict[str, str]
2536
session: _LazySession
2637
user: UserLazyObject | None
2738

39+
# Accepts any ASGI message dict with a required "type" key (str),
40+
# but allows additional arbitrary keys for flexibility.
2841
def get_handler_name(message: dict[str, Any]) -> str: ...
2942
@type_check_only
3043
class _ASGIApplicationProtocol(Protocol):
31-
consumer_class: Any
32-
consumer_initkwargs: dict[str, Any]
44+
consumer_class: AsyncConsumer
45+
46+
# Accepts any initialization kwargs passed to the consumer class.
47+
# Typed as `Any` to allow flexibility in subclass-specific arguments.
48+
consumer_initkwargs: Any
3349

3450
def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> Awaitable[None]: ...
3551

3652
class AsyncConsumer:
37-
_sync: ClassVar[bool] = ...
38-
channel_layer_alias: ClassVar[str] = ...
53+
channel_layer_alias: ClassVar[str]
3954

4055
scope: _ChannelScope
4156
channel_layer: BaseChannelLayer
@@ -50,8 +65,9 @@ class AsyncConsumer:
5065
def as_asgi(cls, **initkwargs: Any) -> _ASGIApplicationProtocol: ...
5166

5267
class SyncConsumer(AsyncConsumer):
53-
_sync: ClassVar[bool] = ...
5468

69+
# Since we're overriding asynchronous methods with synchronous ones,
70+
# we need to use `# type: ignore[override]` to suppress mypy errors.
5571
@database_sync_to_async
5672
def dispatch(self, message: dict[str, Any]) -> None: ... # type: ignore[override]
5773
def send(self, message: dict[str, Any]) -> None: ... # type: ignore[override]

stubs/channels/channels/db.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from asyncio import BaseEventLoop
22
from collections.abc import Callable, Coroutine
3+
from concurrent.futures import ThreadPoolExecutor
34
from typing import Any, TypeVar
45
from typing_extensions import ParamSpec
56

@@ -11,5 +12,11 @@ _R = TypeVar("_R")
1112
class DatabaseSyncToAsync(SyncToAsync[_P, _R]):
1213
def thread_handler(self, loop: BaseEventLoop, *args: Any, **kwargs: Any) -> Any: ...
1314

14-
def database_sync_to_async(func: Callable[_P, _R]) -> Callable[_P, Coroutine[Any, Any, _R]]: ...
15+
# We define `database_sync_to_async` as a function instead of assigning
16+
# `DatabaseSyncToAsync(...)` directly, to preserve both decorator and
17+
# higher-order function behavior with correct type hints.
18+
# A direct assignment would result in incorrect type inference for the wrapped function.
19+
def database_sync_to_async(
20+
func: Callable[_P, _R], thread_sensitive: bool = True, executor: ThreadPoolExecutor | None = None
21+
) -> Callable[_P, Coroutine[Any, Any, _R]]: ...
1522
async def aclose_old_connections() -> None: ...
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from _typeshed import Unused
12
from collections.abc import Iterable
23
from typing import Any
34

@@ -8,12 +9,11 @@ class AsyncHttpConsumer(AsyncConsumer):
89
body: list[bytes]
910
scope: HTTPScope # type: ignore[assignment]
1011

11-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
12-
async def send_headers(self, *, status: int = ..., headers: Iterable[tuple[bytes, bytes]] | None = ...) -> None: ...
13-
async def send_body(self, body: bytes, *, more_body: bool = ...) -> None: ...
12+
def __init__(self, *args: Unused, **kwargs: Unused) -> None: ...
13+
async def send_headers(self, *, status: int = 200, headers: Iterable[tuple[bytes, bytes]] | None = None) -> None: ...
14+
async def send_body(self, body: bytes, *, more_body: bool = False) -> None: ...
1415
async def send_response(self, status: int, body: bytes, **kwargs: Any) -> None: ...
1516
async def handle(self, body: bytes) -> None: ...
1617
async def disconnect(self) -> None: ...
1718
async def http_request(self, message: HTTPRequestEvent) -> None: ...
1819
async def http_disconnect(self, message: HTTPDisconnectEvent) -> None: ...
19-
async def send(self, message: dict[str, Any]) -> None: ... # type: ignore[override]

stubs/channels/channels/generic/websocket.pyi

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@ class WebsocketConsumer(SyncConsumer):
1515
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
1616
def websocket_connect(self, message: WebSocketConnectEvent) -> None: ...
1717
def connect(self) -> None: ...
18-
def accept(self, subprotocol: str | None = ..., headers: list[tuple[str, str]] | None = ...) -> None: ...
18+
def accept(self, subprotocol: str | None = None, headers: list[tuple[str, str]] | None = None) -> None: ...
1919
def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ...
20-
def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ...) -> None: ...
20+
def receive(self, text_data: str | None = None, bytes_data: bytes | None = None) -> None: ...
2121
def send( # type: ignore[override]
22-
self, text_data: str | None = ..., bytes_data: bytes | None = ..., close: bool = ...
22+
self, text_data: str | None = None, bytes_data: bytes | None = None, close: bool = False
2323
) -> None: ...
24-
def close(self, code: int | bool | None = ..., reason: str | None = ...) -> None: ...
24+
def close(self, code: int | bool | None = None, reason: str | None = None) -> None: ...
2525
def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ...
2626
def disconnect(self, code: int) -> None: ...
2727

2828
class JsonWebsocketConsumer(WebsocketConsumer):
29-
def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ..., **kwargs: Any) -> None: ...
29+
def receive(self, text_data: str | None = None, bytes_data: bytes | None = None, **kwargs: Any) -> None: ...
3030
def receive_json(self, content: Any, **kwargs: Any) -> None: ...
31-
def send_json(self, content: Any, close: bool = ...) -> None: ...
31+
def send_json(self, content: Any, close: bool = False) -> None: ...
3232
@classmethod
3333
def decode_json(cls, text_data: str) -> Any: ...
3434
@classmethod
@@ -45,20 +45,20 @@ class AsyncWebsocketConsumer(AsyncConsumer):
4545
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
4646
async def websocket_connect(self, message: WebSocketConnectEvent) -> None: ...
4747
async def connect(self) -> None: ...
48-
async def accept(self, subprotocol: str | None = ..., headers: list[tuple[str, str]] | None = ...) -> None: ...
48+
async def accept(self, subprotocol: str | None = None, headers: list[tuple[str, str]] | None = None) -> None: ...
4949
async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: ...
50-
async def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ...) -> None: ...
50+
async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None) -> None: ...
5151
async def send( # type: ignore[override]
52-
self, text_data: str | None = ..., bytes_data: bytes | None = ..., close: bool = ...
52+
self, text_data: str | None = None, bytes_data: bytes | None = None, close: bool = False
5353
) -> None: ...
54-
async def close(self, code: int | bool | None = ..., reason: str | None = ...) -> None: ...
54+
async def close(self, code: int | bool | None = None, reason: str | None = None) -> None: ...
5555
async def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None: ...
5656
async def disconnect(self, code: int) -> None: ...
5757

5858
class AsyncJsonWebsocketConsumer(AsyncWebsocketConsumer):
59-
async def receive(self, text_data: str | None = ..., bytes_data: bytes | None = ..., **kwargs: Any) -> None: ...
59+
async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None, **kwargs: Any) -> None: ...
6060
async def receive_json(self, content: Any, **kwargs: Any) -> None: ...
61-
async def send_json(self, content: Any, close: bool = ...) -> None: ...
61+
async def send_json(self, content: Any, close: bool = False) -> None: ...
6262
@classmethod
6363
async def decode_json(cls, text_data: str) -> Any: ...
6464
@classmethod

stubs/channels/channels/layers.pyi

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from re import Pattern
3-
from typing import Any, overload
3+
from typing import Any, ClassVar, overload
44
from typing_extensions import TypeAlias, deprecated
55

66
class ChannelLayerManager:
@@ -20,33 +20,33 @@ _ChannelCapacityDict: TypeAlias = dict[_ChannelCapacityPattern, int]
2020
_CompiledChannelCapacity: TypeAlias = list[tuple[Pattern[str], int]]
2121

2222
class BaseChannelLayer:
23-
MAX_NAME_LENGTH: int = ...
23+
MAX_NAME_LENGTH: ClassVar[int] = 100
2424
expiry: int
2525
capacity: int
2626
channel_capacity: _ChannelCapacityDict
2727
channel_name_regex: Pattern[str]
2828
group_name_regex: Pattern[str]
2929
invalid_name_error: str
3030

31-
def __init__(self, expiry: int = ..., capacity: int = ..., channel_capacity: _ChannelCapacityDict | None = ...) -> None: ...
31+
def __init__(self, expiry: int = 60, capacity: int = 100, channel_capacity: _ChannelCapacityDict | None = None) -> None: ...
3232
def compile_capacities(self, channel_capacity: _ChannelCapacityDict) -> _CompiledChannelCapacity: ...
3333
def get_capacity(self, channel: str) -> int: ...
3434
@overload
3535
def match_type_and_length(self, name: str) -> bool: ...
3636
@overload
3737
def match_type_and_length(self, name: Any) -> bool: ...
3838
@overload
39-
def require_valid_channel_name(self, name: str, receive: bool = ...) -> bool: ...
39+
def require_valid_channel_name(self, name: str, receive: bool = False) -> bool: ...
4040
@overload
41-
def require_valid_channel_name(self, name: Any, receive: bool = ...) -> bool: ...
41+
def require_valid_channel_name(self, name: Any, receive: bool = False) -> bool: ...
4242
@overload
4343
def require_valid_group_name(self, name: str) -> bool: ...
4444
@overload
4545
def require_valid_group_name(self, name: Any) -> bool: ...
4646
@overload
47-
def valid_channel_names(self, names: list[str], receive: bool = ...) -> bool: ...
47+
def valid_channel_names(self, names: list[str], receive: bool = False) -> bool: ...
4848
@overload
49-
def valid_channel_names(self, names: list[Any], receive: bool = ...) -> bool: ...
49+
def valid_channel_names(self, names: list[Any], receive: bool = False) -> bool: ...
5050
def non_local_name(self, name: str) -> str: ...
5151
async def send(self, channel: str, message: dict[str, Any]) -> None: ...
5252
async def receive(self, channel: str) -> dict[str, Any]: ...
@@ -56,7 +56,7 @@ class BaseChannelLayer:
5656
async def group_discard(self, group: str, channel: str) -> None: ...
5757
async def group_send(self, group: str, message: dict[str, Any]) -> None: ...
5858
@deprecated("Use require_valid_channel_name instead.")
59-
def valid_channel_name(self, channel_name: str, receive: bool = ...) -> bool: ...
59+
def valid_channel_name(self, channel_name: str, receive: bool = False) -> bool: ...
6060
@deprecated("Use require_valid_group_name instead.")
6161
def valid_group_name(self, group_name: str) -> bool: ...
6262

@@ -69,9 +69,9 @@ class InMemoryChannelLayer(BaseChannelLayer):
6969

7070
def __init__(
7171
self,
72-
expiry: int = ...,
73-
group_expiry: int = ...,
74-
capacity: int = ...,
72+
expiry: int = 60,
73+
group_expiry: int = 86400,
74+
capacity: int = 100,
7575
channel_capacity: _ChannelCapacityDict | None = ...,
7676
**kwargs: Any,
7777
) -> None: ...
@@ -80,7 +80,7 @@ class InMemoryChannelLayer(BaseChannelLayer):
8080

8181
async def send(self, channel: str, message: dict[str, Any]) -> None: ...
8282
async def receive(self, channel: str) -> dict[str, Any]: ...
83-
async def new_channel(self, prefix: str = ...) -> str: ...
83+
async def new_channel(self, prefix: str = "specific.") -> str: ...
8484
async def flush(self) -> None: ...
8585
async def close(self) -> None: ...
8686
async def group_add(self, group: str, channel: str) -> None: ...

0 commit comments

Comments
 (0)