33from dataclasses import dataclass , field
44from functools import partial
55from typing import Final
6+ from uuid import uuid4
67
78import aio_pika
89from pydantic import NonNegativeInt
910
1011from ..logging_utils import log_catch , log_context
1112from ._client_base import RabbitMQClientBase
12- from ._models import MessageHandler , RabbitMessage
13+ from ._models import (
14+ ConsumerTag ,
15+ ExchangeName ,
16+ MessageHandler ,
17+ QueueName ,
18+ RabbitMessage ,
19+ TopicName ,
20+ )
1321from ._utils import (
1422 RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS ,
1523 declare_queue ,
2634_DEFAULT_UNEXPECTED_ERROR_RETRY_DELAY_S : Final [float ] = 1
2735_DEFAULT_UNEXPECTED_ERROR_MAX_ATTEMPTS : Final [NonNegativeInt ] = 15
2836
29- _DELAYED_EXCHANGE_NAME : Final [str ] = "delayed_{exchange_name}"
37+ _DELAYED_EXCHANGE_NAME : Final [ExchangeName ] = ExchangeName ("delayed_{exchange_name}" )
38+ _DELAYED_QUEUE_NAME : Final [ExchangeName ] = ExchangeName ("delayed_{queue_name}" )
3039
3140
3241def _get_x_death_count (message : aio_pika .abc .AbstractIncomingMessage ) -> int :
@@ -138,25 +147,30 @@ async def _get_channel(self) -> aio_pika.abc.AbstractChannel:
138147 channel .close_callbacks .add (self ._channel_close_callback )
139148 return channel
140149
141- async def _get_consumer_tag (self , exchange_name ) -> str :
142- return f"{ get_rabbitmq_client_unique_name (self .client_name )} _{ exchange_name } "
150+ async def _create_consumer_tag (self , exchange_name ) -> ConsumerTag :
151+ return ConsumerTag (
152+ f"{ get_rabbitmq_client_unique_name (self .client_name )} _{ exchange_name } _{ uuid4 ()} "
153+ )
143154
144155 async def subscribe (
145156 self ,
146- exchange_name : str ,
157+ exchange_name : ExchangeName ,
147158 message_handler : MessageHandler ,
148159 * ,
149160 exclusive_queue : bool = True ,
161+ non_exclusive_queue_name : str | None = None ,
150162 topics : list [str ] | None = None ,
151163 message_ttl : NonNegativeInt = RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS ,
152164 unexpected_error_retry_delay_s : float = _DEFAULT_UNEXPECTED_ERROR_RETRY_DELAY_S ,
153165 unexpected_error_max_attempts : int = _DEFAULT_UNEXPECTED_ERROR_MAX_ATTEMPTS ,
154- ) -> str :
166+ ) -> tuple [ QueueName , ConsumerTag ] :
155167 """subscribe to exchange_name calling ``message_handler`` for every incoming message
156168 - exclusive_queue: True means that every instance of this application will
157169 receive the incoming messages
158170 - exclusive_queue: False means that only one instance of this application will
159171 reveice the incoming message
172+ - non_exclusive_queue_name: if exclusive_queue is False, then this name will be used. If None
173+ it will use the exchange_name.
160174
161175 NOTE: ``message_ttl` is also a soft timeout: if the handler does not finish processing
162176 the message before this is reached the message will be redelivered!
@@ -182,7 +196,7 @@ async def subscribe(
182196 aio_pika.exceptions.ChannelPreconditionFailed: In case an existing exchange with
183197 different type is used
184198 Returns:
185- queue name
199+ tuple of queue name and consumer tag mapping
186200 """
187201
188202 assert self ._channel_pool # nosec
@@ -212,7 +226,7 @@ async def subscribe(
212226 queue = await declare_queue (
213227 channel ,
214228 self .client_name ,
215- exchange_name ,
229+ non_exclusive_queue_name or exchange_name ,
216230 exclusive_queue = exclusive_queue ,
217231 message_ttl = message_ttl ,
218232 arguments = {"x-dead-letter-exchange" : delayed_exchange_name },
@@ -227,31 +241,33 @@ async def subscribe(
227241 delayed_exchange = await channel .declare_exchange (
228242 delayed_exchange_name , aio_pika .ExchangeType .FANOUT , durable = True
229243 )
244+ delayed_queue_name = _DELAYED_QUEUE_NAME .format (
245+ queue_name = non_exclusive_queue_name or exchange_name
246+ )
230247
231248 delayed_queue = await declare_queue (
232249 channel ,
233250 self .client_name ,
234- delayed_exchange_name ,
251+ delayed_queue_name ,
235252 exclusive_queue = exclusive_queue ,
236253 message_ttl = int (unexpected_error_retry_delay_s * 1000 ),
237254 arguments = {"x-dead-letter-exchange" : exchange .name },
238255 )
239256 await delayed_queue .bind (delayed_exchange )
240257
241- _consumer_tag = await self ._get_consumer_tag (exchange_name )
258+ consumer_tag = await self ._create_consumer_tag (exchange_name )
242259 await queue .consume (
243260 partial (_on_message , message_handler , unexpected_error_max_attempts ),
244261 exclusive = exclusive_queue ,
245- consumer_tag = _consumer_tag ,
262+ consumer_tag = consumer_tag ,
246263 )
247- output : str = queue .name
248- return output
264+ return queue .name , consumer_tag
249265
250266 async def add_topics (
251267 self ,
252- exchange_name : str ,
268+ exchange_name : ExchangeName ,
253269 * ,
254- topics : list [str ],
270+ topics : list [TopicName ],
255271 ) -> None :
256272 assert self ._channel_pool # nosec
257273
@@ -275,9 +291,9 @@ async def add_topics(
275291
276292 async def remove_topics (
277293 self ,
278- exchange_name : str ,
294+ exchange_name : ExchangeName ,
279295 * ,
280- topics : list [str ],
296+ topics : list [TopicName ],
281297 ) -> None :
282298 assert self ._channel_pool # nosec
283299 async with self ._channel_pool .acquire () as channel :
@@ -300,15 +316,24 @@ async def remove_topics(
300316
301317 async def unsubscribe (
302318 self ,
303- queue_name : str ,
319+ queue_name : QueueName ,
304320 ) -> None :
321+ """This will delete the queue if there are no consumers left"""
322+ assert self ._connection_pool # nosec
323+ if self ._connection_pool .is_closed :
324+ _logger .warning (
325+ "Connection to RabbitMQ is already closed, skipping unsubscribe from queue..."
326+ )
327+ return
305328 assert self ._channel_pool # nosec
306329 async with self ._channel_pool .acquire () as channel :
307330 queue = await channel .get_queue (queue_name )
308331 # NOTE: we force delete here
309332 await queue .delete (if_unused = False , if_empty = False )
310333
311- async def publish (self , exchange_name : str , message : RabbitMessage ) -> None :
334+ async def publish (
335+ self , exchange_name : ExchangeName , message : RabbitMessage
336+ ) -> None :
312337 """publish message in the exchange exchange_name.
313338 specifying a topic will use a TOPIC type of RabbitMQ Exchange instead of FANOUT
314339
@@ -333,10 +358,18 @@ async def publish(self, exchange_name: str, message: RabbitMessage) -> None:
333358 routing_key = message .routing_key () or "" ,
334359 )
335360
336- async def unsubscribe_consumer (self , exchange_name : str ):
361+ async def unsubscribe_consumer (
362+ self , queue_name : QueueName , consumer_tag : ConsumerTag
363+ ) -> None :
364+ """This will only remove the consumers without deleting the queue"""
365+ assert self ._connection_pool # nosec
366+ if self ._connection_pool .is_closed :
367+ _logger .warning (
368+ "Connection to RabbitMQ is already closed, skipping unsubscribe consumers from queue..."
369+ )
370+ return
337371 assert self ._channel_pool # nosec
338372 async with self ._channel_pool .acquire () as channel :
339- queue_name = exchange_name
373+ assert isinstance ( channel , aio_pika . RobustChannel ) # nosec
340374 queue = await channel .get_queue (queue_name )
341- _consumer_tag = await self ._get_consumer_tag (exchange_name )
342- await queue .cancel (_consumer_tag )
375+ await queue .cancel (consumer_tag )
0 commit comments