Skip to content

Commit 21ff64d

Browse files
committed
FIX REVIEW
1 parent 4da1969 commit 21ff64d

File tree

4 files changed

+71
-60
lines changed

4 files changed

+71
-60
lines changed

channels_redis/core.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
create_pool,
2020
decode_hosts,
2121
)
22-
from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional
2322

24-
if TYPE_CHECKING:
23+
import typing
24+
25+
if typing.TYPE_CHECKING:
2526
from redis.asyncio.connection import ConnectionPool
2627
from redis.asyncio.client import Redis
27-
from .core import RedisChannelLayer
2828
from typing_extensions import Buffer
2929

3030
logger = logging.getLogger(__name__)
@@ -39,10 +39,10 @@ class ChannelLock:
3939
"""
4040

4141
def __init__(self):
42-
self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
42+
self.locks: collections.defaultdict[str, asyncio.Lock] = (
4343
collections.defaultdict(asyncio.Lock)
4444
)
45-
self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
45+
self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
4646
int
4747
)
4848

@@ -87,7 +87,7 @@ class RedisLoopLayer:
8787
def __init__(self, channel_layer: "RedisChannelLayer"):
8888
self._lock = asyncio.Lock()
8989
self.channel_layer = channel_layer
90-
self._connections: "Dict[int, Redis]" = {}
90+
self._connections: typing.Dict[int, "Redis"] = {}
9191

9292
def get_connection(self, index: int) -> "Redis":
9393
if index not in self._connections:
@@ -145,7 +145,7 @@ def __init__(
145145
symmetric_encryption_keys=symmetric_encryption_keys,
146146
)
147147
# Cached redis connection pools and the event loop they are from
148-
self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
148+
self._layers: typing.Dict[asyncio.AbstractEventLoop, "RedisLoopLayer"] = {}
149149
# Normal channels choose a host index by cycling through the available hosts
150150
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
151151
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -154,15 +154,15 @@ def __init__(
154154
# Number of coroutines trying to receive right now
155155
self.receive_count = 0
156156
# The receive lock
157-
self.receive_lock: "Optional[asyncio.Lock]" = None
157+
self.receive_lock: typing.Optional[asyncio.Lock] = None
158158
# Event loop they are trying to receive on
159-
self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
159+
self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
160160
# Buffered messages by process-local channel name
161-
self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
161+
self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
162162
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
163163
)
164164
# Detached channel cleanup tasks
165-
self.receive_cleaners: "List[asyncio.Task]" = []
165+
self.receive_cleaners: typing.List[asyncio.Task] = []
166166
# Per-channel cleanup locks to prevent a receive starting and moving
167167
# a message back into the main queue before its cleanup has completed
168168
self.receive_clean_locks = ChannelLock()
@@ -180,7 +180,7 @@ async def send(self, channel: str, message):
180180
"""
181181
# Typecheck
182182
assert isinstance(message, dict), "message is not a dict"
183-
assert self.require_valid_channel_name(channel), "Channel name not valid"
183+
assert self.valid_channel_name(channel), "Channel name not valid"
184184
# Make sure the message does not contain reserved keys
185185
assert "__asgi_channel__" not in message
186186
# If it's a process-local channel, strip off local part and stick full name in message
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221221
return channel + "$inflight"
222222

223223
async def _brpop_with_clean(
224-
self, index: int, channel: str, timeout: "Union[int, float, bytes, str]"
224+
self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
225225
):
226226
"""
227227
Perform a Redis BRPOP and manage the backup processing queue.
@@ -269,7 +269,7 @@ async def receive(self, channel: str):
269269
"""
270270
# Make sure the channel name is valid then get the non-local part
271271
# and thus its index
272-
assert self.require_valid_channel_name(channel)
272+
assert self.valid_channel_name(channel)
273273
if "!" in channel:
274274
real_channel = self.non_local_name(channel)
275275
assert real_channel.endswith(
@@ -385,14 +385,12 @@ async def receive(self, channel: str):
385385
# Do a plain direct receive
386386
return (await self.receive_single(channel))[1]
387387

388-
async def receive_single(self, channel: str) -> "Tuple":
388+
async def receive_single(self, channel: str) -> typing.Tuple:
389389
"""
390390
Receives a single message off of the channel and returns it.
391391
"""
392392
# Check channel name
393-
assert self.require_valid_channel_name(
394-
channel, receive=True
395-
), "Channel name invalid"
393+
assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
396394
# Work out the connection to use
397395
if "!" in channel:
398396
assert channel.endswith("!")
@@ -423,7 +421,7 @@ async def receive_single(self, channel: str) -> "Tuple":
423421
)
424422
self.receive_cleaners.append(cleaner)
425423

426-
def _cleanup_done(cleaner: "asyncio.Task"):
424+
def _cleanup_done(cleaner: asyncio.Task):
427425
self.receive_cleaners.remove(cleaner)
428426
self.receive_clean_locks.release(channel_key)
429427

@@ -497,8 +495,8 @@ async def group_add(self, group: str, channel: str):
497495
Adds the channel name to a group.
498496
"""
499497
# Check the inputs
500-
assert self.require_valid_group_name(group), True
501-
assert self.require_valid_channel_name(channel), True
498+
assert self.valid_group_name(group), True
499+
assert self.valid_channel_name(channel), True
502500
# Get a connection to the right shard
503501
group_key = self._group_key(group)
504502
connection = self.connection(self.consistent_hash(group))
@@ -513,8 +511,8 @@ async def group_discard(self, group: str, channel: str):
513511
Removes the channel from the named group if it is in the group;
514512
does nothing otherwise (does not error)
515513
"""
516-
assert self.require_valid_group_name(group), "Group name not valid"
517-
assert self.require_valid_channel_name(channel), "Channel name not valid"
514+
assert self.valid_group_name(group), "Group name not valid"
515+
assert self.valid_channel_name(channel), "Channel name not valid"
518516
key = self._group_key(group)
519517
connection = self.connection(self.consistent_hash(group))
520518
await connection.zrem(key, channel)
@@ -523,7 +521,7 @@ async def group_send(self, group: str, message):
523521
"""
524522
Sends a message to the entire group.
525523
"""
526-
assert self.require_valid_group_name(group), "Group name not valid"
524+
assert self.valid_group_name(group), "Group name not valid"
527525
# Retrieve list of all channel names
528526
key = self._group_key(group)
529527
connection = self.connection(self.consistent_hash(group))
@@ -671,7 +669,7 @@ def deserialize(self, message: bytes):
671669

672670
### Internal functions ###
673671

674-
def consistent_hash(self, value: "Union[str, Buffer]") -> int:
672+
def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
675673
return _consistent_hash(value, self.ring_size)
676674

677675
def __str__(self):

channels_redis/pubsub.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313
create_pool,
1414
decode_hosts,
1515
)
16-
from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable
1716

18-
if TYPE_CHECKING:
17+
import typing
18+
19+
if typing.TYPE_CHECKING:
1920
from redis.asyncio.client import Redis
2021
from typing_extensions import Buffer
2122
logger = logging.getLogger(__name__)
2223

2324

24-
async def _async_proxy(obj: "RedisPubSubChannelLayer", name: "str", *args, **kwargs):
25+
async def _async_proxy(obj: "RedisPubSubChannelLayer", name: str, *args, **kwargs):
2526
# Must be defined as a function and not a method due to
2627
# https://bugs.python.org/issue38364
2728
layer = obj._get_layer()
@@ -32,13 +33,15 @@ class RedisPubSubChannelLayer:
3233
def __init__(
3334
self,
3435
*args,
35-
symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None,
36+
symmetric_encryption_keys: typing.Optional[
37+
typing.Iterable[typing.Union[str, "Buffer"]]
38+
] = None,
3639
serializer_format="msgpack",
3740
**kwargs,
3841
):
3942
self._args = args
4043
self._kwargs = kwargs
41-
self._layers: "Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer]" = {}
44+
self._layers: typing.Dict[asyncio.AbstractEventLoop, RedisPubSubLoopLayer] = {}
4245
# serialization
4346
self._serializer = registry.get_serializer(
4447
serializer_format,
@@ -95,11 +98,11 @@ class RedisPubSubLoopLayer:
9598

9699
def __init__(
97100
self,
98-
hosts: "Union[Iterable, str, bytes, None]" = None,
99-
prefix="asgi",
101+
hosts: typing.Union[typing.Iterable, str, bytes, None] = None,
102+
prefix: str = "asgi",
100103
on_disconnect=None,
101104
on_reconnect=None,
102-
channel_layer: "Optional[RedisPubSubChannelLayer]" = None,
105+
channel_layer: typing.Optional[RedisPubSubChannelLayer] = None,
103106
**kwargs,
104107
):
105108
self.prefix = prefix
@@ -110,7 +113,7 @@ def __init__(
110113

111114
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
112115
# This dict maps `channel_name` to a queue of messages for that channel.
113-
self.channels: "Dict[Any, asyncio.Queue]" = {}
116+
self.channels: typing.Dict[typing.Any, asyncio.Queue] = {}
114117

115118
# A channel can subscribe to zero or more groups.
116119
# This dict maps `group_name` to set of channel names who are subscribed to that group.
@@ -122,7 +125,7 @@ def __init__(
122125
]
123126

124127
def _get_shard(
125-
self, channel_or_group_name: "Union[str, Buffer]"
128+
self, channel_or_group_name: typing.Union[str, "Buffer"]
126129
) -> "RedisSingleShardConnection":
127130
"""
128131
Return the shard that is used exclusively for this channel or group.
@@ -265,17 +268,21 @@ async def flush(self):
265268

266269

267270
class RedisSingleShardConnection:
268-
def __init__(self, host: "Dict[str, Any]", channel_layer: "RedisPubSubLoopLayer"):
271+
def __init__(
272+
self, host: typing.Dict[str, typing.Any], channel_layer: RedisPubSubLoopLayer
273+
):
269274
self.host = host
270275
self.channel_layer = channel_layer
271276
self._subscribed_to = set()
272277
self._lock = asyncio.Lock()
273-
self._redis: "Optional[Redis]" = None
278+
self._redis: typing.Optional["Redis"] = None
274279
self._pubsub = None
275-
self._receive_task: "Optional[asyncio.Task]" = None
280+
self._receive_task: typing.Optional[asyncio.Task] = None
276281

277282
async def publish(
278-
self, channel: "Union[str, bytes]", message: "Union[str, bytes, int, float]"
283+
self,
284+
channel: typing.Union[str, bytes],
285+
message: typing.Union[str, bytes, int, float],
279286
):
280287
async with self._lock:
281288
self._ensure_redis()
@@ -335,7 +342,7 @@ async def _do_receiving(self):
335342
logger.exception("Unexpected exception in receive task")
336343
await asyncio.sleep(1)
337344

338-
def _receive_message(self, message: "Optional[Dict]"):
345+
def _receive_message(self, message: typing.Optional[typing.Dict]):
339346
if message is not None:
340347
name = message["channel"]
341348
data = message["data"]

channels_redis/serializers.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import hashlib
44
import json
55
import random
6-
from typing import TYPE_CHECKING, Dict, Union, Optional, Any, Iterable, Type, cast
6+
import typing
77

8-
if TYPE_CHECKING:
8+
if typing.TYPE_CHECKING:
99
from typing_extensions import Buffer
1010

1111
try:
@@ -24,19 +24,23 @@ class SerializerDoesNotExist(KeyError):
2424
class BaseMessageSerializer(abc.ABC):
2525
def __init__(
2626
self,
27-
symmetric_encryption_keys: "Optional[Iterable[Union[str, Buffer]]]" = None,
27+
symmetric_encryption_keys: typing.Optional[
28+
typing.Iterable[typing.Union[str, "Buffer"]]
29+
] = None,
2830
random_prefix_length: int = 0,
29-
expiry: "Optional[int]" = None,
31+
expiry: typing.Optional[int] = None,
3032
):
3133
self.random_prefix_length = random_prefix_length
3234
self.expiry = expiry
33-
self.crypter: "Optional[MultiFernet]" = None
35+
self.crypter: typing.Optional["MultiFernet"] = None
3436
# Set up any encryption objects
3537
self._setup_encryption(symmetric_encryption_keys)
3638

3739
def _setup_encryption(
3840
self,
39-
symmetric_encryption_keys: "Optional[Union[Iterable[Union[str, Buffer]], str, bytes]]",
41+
symmetric_encryption_keys: typing.Optional[
42+
typing.Union[typing.Iterable[typing.Union[str, "Buffer"]], str, bytes]
43+
],
4044
):
4145
# See if we can do encryption if they asked
4246
if not symmetric_encryption_keys:
@@ -47,15 +51,17 @@ def _setup_encryption(
4751
raise ValueError(
4852
"symmetric_encryption_keys must be a list of possible keys"
4953
)
50-
keys = cast("Iterable[Union[str, Buffer]]", symmetric_encryption_keys)
54+
keys = typing.cast(
55+
typing.Iterable[typing.Union[str, "Buffer"]], symmetric_encryption_keys
56+
)
5157
if _MultiFernet is None:
5258
raise ValueError(
5359
"Cannot run with encryption without 'cryptography' installed."
5460
)
5561
sub_fernets = [self.make_fernet(key) for key in keys]
5662
self.crypter = _MultiFernet(sub_fernets)
5763

58-
def make_fernet(self, key: "Union[str, Buffer]") -> "Fernet":
64+
def make_fernet(self, key: typing.Union[str, "Buffer"]) -> "Fernet":
5965
"""
6066
Given a single encryption key, returns a Fernet instance using it.
6167
"""
@@ -129,7 +135,7 @@ def as_bytes(self, message, *args, **kwargs) -> bytes:
129135

130136

131137
# code ready for a future in which msgpack may become an optional dependency
132-
MsgPackSerializer: "Optional[Type[BaseMessageSerializer]]" = None
138+
MsgPackSerializer: typing.Optional[typing.Type[BaseMessageSerializer]] = None
133139
try:
134140
import msgpack
135141
except ImportError as exc:
@@ -153,10 +159,10 @@ class SerializersRegistry:
153159
"""
154160

155161
def __init__(self):
156-
self._registry: "Dict[Any, Type[BaseMessageSerializer]]" = {}
162+
self._registry: typing.Dict[typing.Any, typing.Type[BaseMessageSerializer]] = {}
157163

158164
def register_serializer(
159-
self, format, serializer_class: "Type[BaseMessageSerializer]"
165+
self, format, serializer_class: typing.Type[BaseMessageSerializer]
160166
):
161167
"""
162168
Register a new serializer for given format
@@ -174,7 +180,7 @@ def register_serializer(
174180

175181
self._registry[format] = serializer_class
176182

177-
def get_serializer(self, format, *args, **kwargs) -> "BaseMessageSerializer":
183+
def get_serializer(self, format, *args, **kwargs) -> BaseMessageSerializer:
178184
try:
179185
serializer_class = self._registry[format]
180186
except KeyError:

0 commit comments

Comments
 (0)