Skip to content

Commit 198754e

Browse files
committed
FIX
1 parent a5a1b26 commit 198754e

File tree

5 files changed

+44
-38
lines changed

5 files changed

+44
-38
lines changed

channels_redis/core.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from redis import asyncio as aioredis
1111

12-
from channels.exceptions import ChannelFull
13-
from channels.layers import BaseChannelLayer
12+
from channels.exceptions import ChannelFull # type: ignore[import-untyped]
13+
from channels.layers import BaseChannelLayer # type: ignore[import-untyped]
1414

1515
from .serializers import registry
1616
from .utils import (
@@ -38,12 +38,10 @@ class ChannelLock:
3838
"""
3939

4040
def __init__(self) -> None:
41-
self.locks: collections.defaultdict[str, asyncio.Lock] = (
42-
collections.defaultdict(asyncio.Lock)
43-
)
44-
self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
45-
int
41+
self.locks: typing.DefaultDict[str, asyncio.Lock] = collections.defaultdict(
42+
asyncio.Lock
4643
)
44+
self.wait_counts: typing.DefaultDict[str, int] = collections.defaultdict(int)
4745

4846
async def acquire(self, channel: str) -> bool:
4947
"""
@@ -69,7 +67,7 @@ def release(self, channel: str) -> None:
6967
del self.wait_counts[channel]
7068

7169

72-
class BoundedQueue(asyncio.Queue[typing.Any]):
70+
class BoundedQueue(asyncio.Queue):
7371
def put_nowait(self, item: typing.Any) -> None:
7472
if self.full():
7573
# see: https://github.com/django/channels_redis/issues/212
@@ -157,11 +155,11 @@ def __init__(
157155
# Event loop they are trying to receive on
158156
self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
159157
# Buffered messages by process-local channel name
160-
self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
158+
self.receive_buffer: typing.DefaultDict[str, BoundedQueue] = (
161159
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
162160
)
163161
# Detached channel cleanup tasks
164-
self.receive_cleaners: typing.List[asyncio.Task[typing.Any]] = []
162+
self.receive_cleaners: typing.List["asyncio.Task[typing.Any]"] = []
165163
# Per-channel cleanup locks to prevent a receive starting and moving
166164
# a message back into the main queue before its cleanup has completed
167165
self.receive_clean_locks = ChannelLock()
@@ -240,7 +238,7 @@ async def _brpop_with_clean(
240238
connection = self.connection(index)
241239
# Cancellation here doesn't matter, we're not doing anything destructive
242240
# and the script executes atomically...
243-
await connection.eval(cleanup_script, 0, channel, backup_queue)
241+
await connection.eval(cleanup_script, 0, channel, backup_queue) # type: ignore[misc]
244242
# ...and it doesn't matter here either, the message will be safe in the backup.
245243
result = await connection.bzpopmin(channel, timeout=timeout)
246244

@@ -271,9 +269,9 @@ async def receive(self, channel: str) -> typing.Any:
271269
assert self.valid_channel_name(channel)
272270
if "!" in channel:
273271
real_channel = self.non_local_name(channel)
274-
assert real_channel.endswith(self.client_prefix + "!"), (
275-
"Wrong client prefix"
276-
)
272+
assert real_channel.endswith(
273+
self.client_prefix + "!"
274+
), "Wrong client prefix"
277275
# Enter receiving section
278276
loop = asyncio.get_running_loop()
279277
self.receive_count += 1
@@ -293,7 +291,7 @@ async def receive(self, channel: str) -> typing.Any:
293291
message = None
294292
while self.receive_buffer[channel].empty():
295293
_tasks = [
296-
self.receive_lock.acquire(),
294+
self.receive_lock.acquire(), # type: ignore[union-attr]
297295
self.receive_buffer[channel].get(),
298296
]
299297
tasks = [asyncio.ensure_future(task) for task in _tasks]
@@ -312,7 +310,7 @@ async def receive(self, channel: str) -> typing.Any:
312310
if not task.cancel():
313311
assert task.done()
314312
if task.result() is True:
315-
self.receive_lock.release()
313+
self.receive_lock.release() # type: ignore[union-attr]
316314

317315
raise
318316

@@ -335,7 +333,7 @@ async def receive(self, channel: str) -> typing.Any:
335333
if message or exception:
336334
if token:
337335
# We will not be receving as we already have the message.
338-
self.receive_lock.release()
336+
self.receive_lock.release() # type: ignore[union-attr]
339337

340338
if exception:
341339
raise exception
@@ -362,7 +360,7 @@ async def receive(self, channel: str) -> typing.Any:
362360
del self.receive_buffer[channel]
363361
raise
364362
finally:
365-
self.receive_lock.release()
363+
self.receive_lock.release() # type: ignore[union-attr]
366364

367365
# We know there's a message available, because there
368366
# couldn't have been any interruption between empty() and here
@@ -377,7 +375,7 @@ async def receive(self, channel: str) -> typing.Any:
377375
self.receive_count -= 1
378376
# If we were the last out, drop the receive lock
379377
if self.receive_count == 0:
380-
assert not self.receive_lock.locked()
378+
assert not self.receive_lock.locked() # type: ignore[union-attr]
381379
self.receive_lock = None
382380
self.receive_event_loop = None
383381
else:
@@ -422,7 +420,7 @@ async def receive_single(
422420
)
423421
self.receive_cleaners.append(cleaner)
424422

425-
def _cleanup_done(cleaner: asyncio.Task[typing.Any]) -> None:
423+
def _cleanup_done(cleaner: "asyncio.Task") -> None:
426424
self.receive_cleaners.remove(cleaner)
427425
self.receive_clean_locks.release(channel_key)
428426

@@ -468,7 +466,7 @@ async def flush(self) -> None:
468466
# Go through each connection and remove all with prefix
469467
for i in range(self.ring_size):
470468
connection = self.connection(i)
471-
await connection.eval(delete_prefix, 0, self.prefix + "*")
469+
await connection.eval(delete_prefix, 0, self.prefix + "*") # type: ignore[union-attr,misc]
472470
# Now clear the pools as well
473471
await self.close_pools()
474472

@@ -584,7 +582,7 @@ async def group_send(self, group: str, message: typing.Any) -> None:
584582

585583
# channel_keys does not contain a single redis key more than once
586584
connection = self.connection(connection_index)
587-
channels_over_capacity = await connection.eval(
585+
channels_over_capacity = await connection.eval( # type: ignore[misc]
588586
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
589587
)
590588
_channels_over_capacity = -1.0

channels_redis/pubsub.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from copy import deepcopy
23
import functools
34
import logging
45
import typing
@@ -16,7 +17,7 @@
1617
)
1718

1819
if typing.TYPE_CHECKING:
19-
from redis.asyncio.client import Redis, PubSub
20+
from redis.asyncio.client import PubSub, Redis
2021
from typing_extensions import Buffer
2122
logger = logging.getLogger(__name__)
2223

@@ -81,10 +82,11 @@ def _get_layer(self) -> "RedisPubSubLoopLayer":
8182
try:
8283
layer = self._layers[loop]
8384
except KeyError:
85+
kwargs = deepcopy(self._kwargs)
86+
kwargs["channel_layer"] = self
8487
layer = RedisPubSubLoopLayer(
8588
*self._args,
86-
**self._kwargs,
87-
channel_layer=self,
89+
**kwargs,
8890
)
8991
self._layers[loop] = layer
9092
_wrap_close(self, loop)
@@ -163,7 +165,7 @@ async def send(
163165
Send a message onto a (general or specific) channel.
164166
"""
165167
shard = self._get_shard(channel)
166-
await shard.publish(channel, self.channel_layer.serialize(message))
168+
await shard.publish(channel, self.channel_layer.serialize(message)) # type: ignore[union-attr]
167169

168170
async def new_channel(self, prefix: str = "specific.") -> str:
169171
"""
@@ -205,7 +207,7 @@ async def receive(self, channel: typing.Union[str, bytes]) -> typing.Any:
205207
# We don't re-raise here because we want the CancelledError to be the one re-raised.
206208
raise
207209

208-
return self.channel_layer.deserialize(message)
210+
return self.channel_layer.deserialize(message) # type: ignore[union-attr]
209211

210212
################################################################################
211213
# Groups extension
@@ -254,7 +256,7 @@ async def group_send(self, group: str, message: typing.Any) -> None:
254256
"""
255257
group_channel = self._get_group_channel_name(group)
256258
shard = self._get_shard(group_channel)
257-
await shard.publish(group_channel, self.channel_layer.serialize(message))
259+
await shard.publish(group_channel, self.channel_layer.serialize(message)) # type: ignore[union-attr]
258260

259261
################################################################################
260262
# Flush extension
@@ -282,7 +284,7 @@ def __init__(
282284
self._lock = asyncio.Lock()
283285
self._redis: typing.Optional["Redis"] = None
284286
self._pubsub: typing.Optional["PubSub"] = None
285-
self._receive_task: typing.Optional[asyncio.Task[typing.Any]] = None
287+
self._receive_task: typing.Optional["asyncio.Task[typing.Any]"] = None
286288

287289
async def publish(
288290
self,
@@ -291,22 +293,22 @@ async def publish(
291293
) -> None:
292294
async with self._lock:
293295
self._ensure_redis()
294-
await self._redis.publish(channel, message)
296+
await self._redis.publish(channel, message) # type: ignore[union-attr]
295297

296298
async def subscribe(self, channel: typing.Union[str, bytes]) -> None:
297299
async with self._lock:
298300
if channel not in self._subscribed_to:
299301
self._ensure_redis()
300302
self._ensure_receiver()
301-
await self._pubsub.subscribe(channel)
303+
await self._pubsub.subscribe(channel) # type: ignore[union-attr]
302304
self._subscribed_to.add(channel)
303305

304306
async def unsubscribe(self, channel: typing.Union[str, bytes]) -> None:
305307
async with self._lock:
306308
if channel in self._subscribed_to:
307309
self._ensure_redis()
308310
self._ensure_receiver()
309-
await self._pubsub.unsubscribe(channel)
311+
await self._pubsub.unsubscribe(channel) # type: ignore[union-attr]
310312
self._subscribed_to.remove(channel)
311313

312314
async def flush(self) -> None:

channels_redis/serializers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,15 @@ def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes:
128128
msg = json.dumps(message, *args, **kwargs)
129129
return msg.encode("utf-8")
130130

131-
from_bytes = staticmethod(json.loads)
131+
from_bytes = staticmethod(json.loads) # type: ignore[assignment]
132132

133133

134134
# code ready for a future in which msgpack may become an optional dependency
135135
MsgPackSerializer: typing.Union[
136136
typing.Type[BaseMessageSerializer], typing.Type[MissingSerializer]
137137
]
138138
try:
139-
import msgpack
139+
import msgpack # type: ignore[import-untyped]
140140
except ImportError as exc:
141141

142142
class _MsgPackSerializer(MissingSerializer):
@@ -146,8 +146,8 @@ class _MsgPackSerializer(MissingSerializer):
146146
else:
147147

148148
class __MsgPackSerializer(BaseMessageSerializer):
149-
as_bytes = staticmethod(msgpack.packb)
150-
from_bytes = staticmethod(msgpack.unpackb)
149+
as_bytes = staticmethod(msgpack.packb) # type: ignore[assignment]
150+
from_bytes = staticmethod(msgpack.unpackb) # type: ignore[assignment]
151151

152152
MsgPackSerializer = __MsgPackSerializer
153153

channels_redis/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _wrapper(self, *args, **kwargs) -> typing.Any:
4747
self.close = original_impl
4848
return self.close(*args, **kwargs)
4949

50-
loop.close = types.MethodType(_wrapper, loop)
50+
loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign]
5151

5252

5353
async def _close_redis(connection: "Redis") -> None:

tox.ini

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ envlist =
77
[testenv]
88
usedevelop = true
99
extras = tests
10+
allowlist_externals = mypy
1011
commands =
1112
pytest -v {posargs}
12-
mypy . {posargs}
13+
; mypy --show-error-codes --config-file=tox.ini channels_redis
1314
deps =
15+
mypy>=1.9.0
1416
ch30: channels>=3.0,<3.1
1517
ch40: channels>=4.0,<4.1
1618
chmain: https://github.com/django/channels/archive/main.tar.gz
@@ -28,3 +30,7 @@ commands =
2830
flake8 channels_redis tests
2931
black --check channels_redis tests
3032
isort --check-only --diff channels_redis tests
33+
34+
35+
[mypy]
36+
disallow_untyped_defs = False

0 commit comments

Comments
 (0)