@@ -168,7 +168,9 @@ def __init__(
168
168
self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
169
169
# Decide on a unique client prefix to use in ! sections
170
170
# TODO: ensure uniqueness better, e.g. Redis keys with SETNX
171
- self .client_prefix = "" .join (random .choice (string .ascii_letters ) for i in range (8 ))
171
+ self .client_prefix = "" .join (
172
+ random .choice (string .ascii_letters ) for i in range (8 )
173
+ )
172
174
# Set up any encryption objects
173
175
self ._setup_encryption (symmetric_encryption_keys )
174
176
# Number of coroutines trying to receive right now
@@ -195,27 +197,31 @@ def decode_hosts(self, hosts):
195
197
return [{"address" : ("localhost" , 6379 )}]
196
198
# If they provided just a string, scold them.
197
199
if isinstance (hosts , (str , bytes )):
198
- raise ValueError ("You must pass a list of Redis hosts, even if there is only one." )
200
+ raise ValueError (
201
+ "You must pass a list of Redis hosts, even if there is only one."
202
+ )
199
203
# Decode each hosts entry into a kwargs dict
200
204
result = []
201
205
for entry in hosts :
202
206
if isinstance (entry , dict ):
203
207
result .append (entry )
204
208
else :
205
- result .append ({
206
- "address" : entry ,
207
- })
209
+ result .append ({"address" : entry })
208
210
return result
209
211
210
212
def _setup_encryption (self , symmetric_encryption_keys ):
211
213
# See if we can do encryption if they asked
212
214
if symmetric_encryption_keys :
213
215
if isinstance (symmetric_encryption_keys , (str , bytes )):
214
- raise ValueError ("symmetric_encryption_keys must be a list of possible keys" )
216
+ raise ValueError (
217
+ "symmetric_encryption_keys must be a list of possible keys"
218
+ )
215
219
try :
216
220
from cryptography .fernet import MultiFernet
217
221
except ImportError :
218
- raise ValueError ("Cannot run with encryption without 'cryptography' installed." )
222
+ raise ValueError (
223
+ "Cannot run with encryption without 'cryptography' installed."
224
+ )
219
225
sub_fernets = [self .make_fernet (key ) for key in symmetric_encryption_keys ]
220
226
self .crypter = MultiFernet (sub_fernets )
221
227
else :
@@ -282,11 +288,7 @@ async def _brpop_with_clean(self, index, channel, timeout):
282
288
async with self .connection (index ) as connection :
283
289
# Cancellation here doesn't matter, we're not doing anything destructive
284
290
# and the script executes atomically...
285
- await connection .eval (
286
- cleanup_script ,
287
- keys = [],
288
- args = [channel , backup_queue ]
289
- )
291
+ await connection .eval (cleanup_script , keys = [], args = [channel , backup_queue ])
290
292
# ...and it doesn't matter here either, the message will be safe in the backup.
291
293
return await connection .brpoplpush (channel , backup_queue , timeout = timeout )
292
294
@@ -309,7 +311,9 @@ async def receive(self, channel):
309
311
assert self .valid_channel_name (channel )
310
312
if "!" in channel :
311
313
real_channel = self .non_local_name (channel )
312
- assert real_channel .endswith (self .client_prefix + "!" ), "Wrong client prefix"
314
+ assert real_channel .endswith (
315
+ self .client_prefix + "!"
316
+ ), "Wrong client prefix"
313
317
# Enter receiving section
314
318
loop = asyncio .get_event_loop ()
315
319
self .receive_count += 1
@@ -321,15 +325,22 @@ async def receive(self, channel):
321
325
else :
322
326
# Otherwise, check our event loop matches
323
327
if self .receive_event_loop != loop :
324
- raise RuntimeError ("Two event loops are trying to receive() on one channel layer at once!" )
328
+ raise RuntimeError (
329
+ "Two event loops are trying to receive() on one channel layer at once!"
330
+ )
325
331
326
332
# Wait for our message to appear
327
333
message = None
328
334
while self .receive_buffer [channel ].empty ():
329
- tasks = [self .receive_lock .acquire (), self .receive_buffer [channel ].get ()]
335
+ tasks = [
336
+ self .receive_lock .acquire (),
337
+ self .receive_buffer [channel ].get (),
338
+ ]
330
339
tasks = [asyncio .ensure_future (task ) for task in tasks ]
331
340
try :
332
- done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
341
+ done , pending = await asyncio .wait (
342
+ tasks , return_when = asyncio .FIRST_COMPLETED
343
+ )
333
344
for task in pending :
334
345
# Cancel all pending tasks.
335
346
task .cancel ()
@@ -374,7 +385,9 @@ async def receive(self, channel):
374
385
# There is no interruption point from when the message is
375
386
# unpacked in receive_single to when we get back here, so
376
387
# the following lines are essentially atomic.
377
- message_channel , message = await self .receive_single (real_channel )
388
+ message_channel , message = await self .receive_single (
389
+ real_channel
390
+ )
378
391
if type (message_channel ) is list :
379
392
for chan in message_channel :
380
393
self .receive_buffer [chan ].put_nowait (message )
@@ -424,7 +437,9 @@ async def receive_single(self, channel):
424
437
while content is None :
425
438
# Nothing is lost here by cancellations, messages will still
426
439
# be in the backup queue.
427
- content = await self ._brpop_with_clean (index , channel_key , timeout = self .brpop_timeout )
440
+ content = await self ._brpop_with_clean (
441
+ index , channel_key , timeout = self .brpop_timeout
442
+ )
428
443
429
444
# Fire off a task to clean the message from its backup queue.
430
445
# Per-channel locking isn't needed, because the backup is a queue
@@ -433,7 +448,9 @@ async def receive_single(self, channel):
433
448
# removed after the next one.
434
449
# NOTE: Duplicate messages will be received eventually if any
435
450
# of these cleaners are cancelled.
436
- cleaner = asyncio .ensure_future (self ._clean_receive_backup (index , channel_key ))
451
+ cleaner = asyncio .ensure_future (
452
+ self ._clean_receive_backup (index , channel_key )
453
+ )
437
454
self .receive_cleaners .append (cleaner )
438
455
439
456
def _cleanup_done (cleaner ):
@@ -442,7 +459,7 @@ def _cleanup_done(cleaner):
442
459
443
460
cleaner .add_done_callback (_cleanup_done )
444
461
445
- except Exception as exc :
462
+ except Exception :
446
463
self .receive_clean_locks .release (channel_key )
447
464
raise
448
465
@@ -487,11 +504,7 @@ async def flush(self):
487
504
# Go through each connection and remove all with prefix
488
505
for i in range (self .ring_size ):
489
506
async with self .connection (i ) as connection :
490
- await connection .eval (
491
- delete_prefix ,
492
- keys = [],
493
- args = [self .prefix + "*" ]
494
- )
507
+ await connection .eval (delete_prefix , keys = [], args = [self .prefix + "*" ])
495
508
# Now clear the pools as well
496
509
await self .close_pools ()
497
510
@@ -526,11 +539,7 @@ async def group_add(self, group, channel):
526
539
group_key = self ._group_key (group )
527
540
async with self .connection (self .consistent_hash (group )) as connection :
528
541
# Add to group sorted set with creation time as timestamp
529
- await connection .zadd (
530
- group_key ,
531
- time .time (),
532
- channel ,
533
- )
542
+ await connection .zadd (group_key , time .time (), channel )
534
543
# Set expiration to be group_expiry, since everything in
535
544
# it at this point is guaranteed to expire before that
536
545
await connection .expire (group_key , self .group_expiry )
@@ -544,10 +553,7 @@ async def group_discard(self, group, channel):
544
553
assert self .valid_channel_name (channel ), "Channel name not valid"
545
554
key = self ._group_key (group )
546
555
async with self .connection (self .consistent_hash (group )) as connection :
547
- await connection .zrem (
548
- key ,
549
- channel ,
550
- )
556
+ await connection .zrem (key , channel )
551
557
552
558
async def group_send (self , group , message ):
553
559
"""
@@ -558,12 +564,17 @@ async def group_send(self, group, message):
558
564
key = self ._group_key (group )
559
565
async with self .connection (self .consistent_hash (group )) as connection :
560
566
# Discard old channels based on group_expiry
561
- await connection .zremrangebyscore (key , min = 0 , max = int (time .time ()) - self .group_expiry )
567
+ await connection .zremrangebyscore (
568
+ key , min = 0 , max = int (time .time ()) - self .group_expiry
569
+ )
562
570
563
- channel_names = [x .decode ("utf8" ) for x in await connection .zrange (key , 0 , - 1 )]
571
+ channel_names = [
572
+ x .decode ("utf8" ) for x in await connection .zrange (key , 0 , - 1 )
573
+ ]
564
574
565
- connection_to_channel_keys , channel_keys_to_message , channel_keys_to_capacity = \
566
- self ._map_channel_keys_to_connection (channel_names , message )
575
+ connection_to_channel_keys , channel_keys_to_message , channel_keys_to_capacity = self ._map_channel_keys_to_connection (
576
+ channel_names , message
577
+ )
567
578
568
579
for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
569
580
@@ -572,24 +583,35 @@ async def group_send(self, group, message):
572
583
# stored in channel_to_message dict and contains the
573
584
# __asgi_channel__ key.
574
585
575
- group_send_lua = """
586
+ group_send_lua = (
587
+ """
576
588
for i=1,#KEYS do
577
589
if redis.call('LLEN', KEYS[i]) < tonumber(ARGV[i + #KEYS]) then
578
590
redis.call('LPUSH', KEYS[i], ARGV[i])
579
591
redis.call('EXPIRE', KEYS[i], %d)
580
592
end
581
593
end
582
- """ % self .expiry
594
+ """
595
+ % self .expiry
596
+ )
583
597
584
598
# We need to filter the messages to keep those related to the connection
585
- args = [channel_keys_to_message [channel_key ] for channel_key in channel_redis_keys ]
599
+ args = [
600
+ channel_keys_to_message [channel_key ]
601
+ for channel_key in channel_redis_keys
602
+ ]
586
603
587
604
# We need to send the capacity for each channel
588
- args += [channel_keys_to_capacity [channel_key ] for channel_key in channel_redis_keys ]
605
+ args += [
606
+ channel_keys_to_capacity [channel_key ]
607
+ for channel_key in channel_redis_keys
608
+ ]
589
609
590
610
# channel_keys does not contain a single redis key more than once
591
611
async with self .connection (connection_index ) as connection :
592
- await connection .eval (group_send_lua , keys = channel_redis_keys , args = args )
612
+ await connection .eval (
613
+ group_send_lua , keys = channel_redis_keys , args = args
614
+ )
593
615
594
616
def _map_channel_to_connection (self , channel_names , message ):
595
617
"""
@@ -619,7 +641,12 @@ def _map_channel_to_connection(self, channel_names, message):
619
641
# We build a
620
642
channel_to_key [channel ] = channel_key
621
643
622
- return connection_to_channels , channel_to_message , channel_to_capacity , channel_to_key
644
+ return (
645
+ connection_to_channels ,
646
+ channel_to_message ,
647
+ channel_to_capacity ,
648
+ channel_to_key ,
649
+ )
623
650
624
651
def _map_channel_keys_to_connection (self , channel_names , message ):
625
652
"""
@@ -665,7 +692,11 @@ def _map_channel_keys_to_connection(self, channel_names, message):
665
692
# Serialize the message stored for each redis key
666
693
channel_key_to_message [key ] = self .serialize (channel_key_to_message [key ])
667
694
668
- return connection_to_channel_keys , channel_key_to_message , channel_key_to_capacity
695
+ return (
696
+ connection_to_channel_keys ,
697
+ channel_key_to_message ,
698
+ channel_key_to_capacity ,
699
+ )
669
700
670
701
def _group_key (self , group ):
671
702
"""
@@ -710,6 +741,7 @@ def make_fernet(self, key):
710
741
Given a single encryption key, returns a Fernet instance using it.
711
742
"""
712
743
from cryptography .fernet import Fernet
744
+
713
745
if isinstance (key , str ):
714
746
key = key .encode ("utf8" )
715
747
formatted_key = base64 .urlsafe_b64encode (hashlib .sha256 (key ).digest ())
@@ -727,14 +759,17 @@ def connection(self, index):
727
759
"""
728
760
# Catch bad indexes
729
761
if not 0 <= index < self .ring_size :
730
- raise ValueError ("There are only %s hosts - you asked for %s!" % (self .ring_size , index ))
762
+ raise ValueError (
763
+ "There are only %s hosts - you asked for %s!" % (self .ring_size , index )
764
+ )
731
765
# Make a context manager
732
766
return self .ConnectionContextManager (self .pools [index ])
733
767
734
768
class ConnectionContextManager :
735
769
"""
736
770
Async context manager for connections
737
771
"""
772
+
738
773
def __init__ (self , pool ):
739
774
self .pool = pool
740
775
0 commit comments