@@ -137,9 +137,8 @@ def __init__(
137
137
self ._setup_encryption (symmetric_encryption_keys )
138
138
# Number of coroutines trying to receive right now
139
139
self .receive_count = 0
140
- # The receiving token queue
141
- # Whoever holds its tokens can receive from a local channel
142
- self .receive_tokens = None
140
+ # The receive lock
141
+ self .receive_lock = None
143
142
# Event loop they are trying to receive on
144
143
self .receive_event_loop = None
145
144
# Buffered messages by process-local channel name
@@ -234,40 +233,87 @@ async def receive(self, channel):
234
233
self .receive_count += 1
235
234
try :
236
235
if self .receive_count == 1 :
237
- # If we're the first coroutine in, create the sharing token!
238
- self .receive_tokens = asyncio .Queue ()
239
- self .receive_tokens .put_nowait (True )
236
+ # If we're the first coroutine in, create the receive lock!
237
+ self .receive_lock = asyncio .Lock ()
240
238
self .receive_event_loop = loop
241
239
else :
242
240
# Otherwise, check our event loop matches
243
241
if self .receive_event_loop != loop :
244
242
raise RuntimeError ("Two event loops are trying to receive() on one channel layer at once!" )
245
243
246
244
# Wait for our message to appear
245
+ message = None
247
246
while self .receive_buffer [channel ].empty ():
248
- token = await self .receive_tokens .get ()
247
+ tasks = [self .receive_lock .acquire (), self .receive_buffer [channel ].get ()]
248
+ tasks = [asyncio .ensure_future (task ) for task in tasks ]
249
249
try :
250
- message_channel , message = await self .receive_single (real_channel )
251
- if type (message_channel ) is list :
252
- for chan in message_channel :
253
- await self .receive_buffer [chan ].put (message )
250
+ done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
251
+ for task in pending :
252
+ # Cancel all pending tasks.
253
+ task .cancel ()
254
+ except asyncio .CancelledError :
255
+ # Ensure all tasks are cancelled if we are cancelled.
256
+ # Also see: https://bugs.python.org/issue23859
257
+ for task in tasks :
258
+ task .cancel ()
259
+
260
+ raise
261
+
262
+ message , token , exception = None , None , None
263
+ for task in done :
264
+ try :
265
+ result = task .result ()
266
+ except Exception as error : # NOQA
267
+ # We should not propagate exceptions immediately as otherwise this may cause
268
+ # the lock to be held and never be released.
269
+ exception = error
270
+ continue
271
+
272
+ if result is True :
273
+ token = result
254
274
else :
255
- await self .receive_buffer [message_channel ].put (message )
256
- finally :
257
- await self .receive_tokens .put (token )
275
+ assert isinstance (result , dict )
276
+ message = result
277
+
278
+ if message or exception :
279
+ if token :
280
+ # We will not be receving as we already have the message.
281
+ self .receive_lock .release ()
282
+
283
+ if exception :
284
+ raise exception
285
+ else :
286
+ break
287
+ else :
288
+ assert token
289
+
290
+ # We hold the receive lock, receive and then release it.
291
+ try :
292
+ message_channel , message = await self .receive_single (real_channel )
293
+ if type (message_channel ) is list :
294
+ for chan in message_channel :
295
+ await self .receive_buffer [chan ].put (message )
296
+ else :
297
+ await self .receive_buffer [message_channel ].put (message )
298
+ message = None
299
+ finally :
300
+ self .receive_lock .release ()
258
301
259
302
# We know there's a message available, because there
260
303
# couldn't have been any interruption between empty() and here
261
- message = self .receive_buffer [channel ].get_nowait ()
304
+ if message is None :
305
+ message = self .receive_buffer [channel ].get_nowait ()
306
+
262
307
if self .receive_buffer [channel ].empty ():
263
308
del self .receive_buffer [channel ]
264
309
return message
265
310
266
311
finally :
267
312
self .receive_count -= 1
268
- # If we were the last out, stop the receive loop
313
+ # If we were the last out, drop the receive lock
269
314
if self .receive_count == 0 :
270
- self .receive_tokens = None
315
+ assert not self .receive_lock .locked ()
316
+ self .receive_lock = None
271
317
else :
272
318
# Do a plain direct receive
273
319
return (await self .receive_single (channel ))[1 ]
0 commit comments