@@ -180,22 +180,39 @@ async def send(self, channel, message):
180
180
"""
181
181
Send a message onto a (general or specific) channel.
182
182
"""
183
+ await self .send_multiple (channel , (message ,))
184
+
185
+ async def send_multiple (self , channel , messages ):
186
+ """
187
+ Send multiple messages at once onto a (general or specific) channel.
188
+ """
183
189
# Typecheck
184
- assert isinstance (message , dict ), "message is not a dict"
185
190
assert self .valid_channel_name (channel ), "Channel name not valid"
186
- # Make sure the message does not contain reserved keys
187
- assert "__asgi_channel__" not in message
191
+ assert hasattr ( messages , "__iter__" ), "messages is not an iterable"
192
+
188
193
# If it's a process-local channel, strip off local part and stick full name in message
189
194
channel_non_local_name = channel
190
- if "!" in channel :
191
- message = dict (message .items ())
192
- message ["__asgi_channel__" ] = channel
195
+ process_local = "!" in channel
196
+ if process_local :
193
197
channel_non_local_name = self .non_local_name (channel )
198
+
199
+ now = time .time ()
200
+ mapping = {}
201
+ for message in messages :
202
+ assert isinstance (message , dict ), "message is not a dict"
203
+ # Make sure the message does not contain reserved keys
204
+ assert "__asgi_channel__" not in message
205
+ if process_local :
206
+ message = dict (message .items ())
207
+ message ["__asgi_channel__" ] = channel
208
+
209
+ mapping [self .serialize (message )] = now
210
+
194
211
# Write out message into expiring key (avoids big items in list)
195
212
channel_key = self .prefix + channel_non_local_name
196
213
# Pick a connection to the right server - consistent for specific
197
214
# channels, random for general channels
198
- if "!" in channel :
215
+ if process_local :
199
216
index = self .consistent_hash (channel )
200
217
else :
201
218
index = next (self ._send_index_generator )
@@ -207,13 +224,13 @@ async def send(self, channel, message):
207
224
208
225
# Check the length of the list before send
209
226
# This can allow the list to leak slightly over capacity, but that's fine.
210
- if await connection .zcount (channel_key , "-inf" , "+inf" ) >= self . get_capacity (
211
- channel
212
- ):
227
+ current_length = await connection .zcount (channel_key , "-inf" , "+inf" )
228
+
229
+ if current_length + len ( messages ) > self . get_capacity ( channel ):
213
230
raise ChannelFull ()
214
231
215
232
# Push onto the list then set it to expire in case it's not consumed
216
- await connection .zadd (channel_key , { self . serialize ( message ): time . time ()} )
233
+ await connection .zadd (channel_key , mapping )
217
234
await connection .expire (channel_key , int (self .expiry ))
218
235
219
236
def _backup_channel_name (self , channel ):
0 commit comments