diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index e71b34d952..165ed87411 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -772,6 +772,8 @@ search: - [RabbitBroker](api/faststream/rabbit/broker/RabbitBroker.md) - broker - [RabbitBroker](api/faststream/rabbit/broker/broker/RabbitBroker.md) + - connection + - [ConnectionManager](api/faststream/rabbit/broker/connection/ConnectionManager.md) - logging - [RabbitLoggingBroker](api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md) - registrator diff --git a/docs/docs/en/api/faststream/rabbit/broker/connection/ConnectionManager.md b/docs/docs/en/api/faststream/rabbit/broker/connection/ConnectionManager.md new file mode 100644 index 0000000000..455fb32242 --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/broker/connection/ConnectionManager.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.broker.connection.ConnectionManager diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index 3d62e1e1d3..d9554cb7f4 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -126,10 +126,12 @@ def __init__( ] = SERVICE_NAME, config: Annotated[ Optional[ConfluentConfig], - Doc(""" + Doc( + """ Extra configuration for the confluent-kafka-python producer/consumer. See `confluent_kafka.Config `_. - """), + """ + ), ] = None, # publisher args acks: Annotated[ diff --git a/faststream/rabbit/annotations.py b/faststream/rabbit/annotations.py index f32654d2cc..0ae2dc7d3e 100644 --- a/faststream/rabbit/annotations.py +++ b/faststream/rabbit/annotations.py @@ -1,4 +1,3 @@ -from aio_pika import RobustChannel, RobustConnection from typing_extensions import Annotated from faststream.annotations import ContextRepo, Logger, NoCast @@ -14,17 +13,12 @@ "RabbitMessage", "RabbitBroker", "RabbitProducer", - "Channel", - "Connection", ) RabbitMessage = Annotated[RM, Context("message")] RabbitBroker = Annotated[RB, Context("broker")] RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")] -Channel = Annotated[RobustChannel, Context("broker._channel")] -Connection = Annotated[RobustConnection, Context("broker._connection")] - # NOTE: transaction is not for the public usage yet # async def _get_transaction(connection: Connection) -> RabbitTransaction: # async with connection.channel(publisher_confirms=False) as channel: diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 6cb357fef7..013914b9d5 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -1,32 +1,19 @@ import logging from inspect import Parameter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast from urllib.parse import urlparse -from aio_pika import connect_robust from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME from faststream.broker.message import gen_cor_id from faststream.exceptions import NOT_CONNECTED_YET +from faststream.rabbit.broker.connection import ConnectionManager from faststream.rabbit.broker.logging import RabbitLoggingBroker from faststream.rabbit.broker.registrator import RabbitRegistrator from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - RABBIT_REPLY, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RabbitExchange, RabbitQueue from faststream.rabbit.security import parse_security from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.rabbit.utils import build_url @@ -37,8 +24,6 @@ from aio_pika import ( IncomingMessage, - RobustChannel, - RobustConnection, RobustExchange, RobustQueue, ) @@ -48,10 +33,7 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi - from faststream.broker.types import ( - BrokerMiddleware, - CustomCallable, - ) + from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.types import AioPikaSendableMessage from faststream.security import BaseSecurity from faststream.types import AnyDict, Decorator, LoggerProto @@ -67,7 +49,6 @@ class RabbitBroker( _producer: Optional["AioPikaFastProducer"] declarer: Optional[RabbitDeclarer] - _channel: Optional["RobustChannel"] def __init__( self, @@ -213,6 +194,14 @@ def __init__( Iterable["Decorator"], Doc("Any custom decorator to apply to wrapped functions."), ] = (), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: security_args = parse_security(security) @@ -234,6 +223,8 @@ def __init__( # respect ascynapi_url argument scheme builded_asyncapi_url = urlparse(asyncapi_url) self.virtual_host = builded_asyncapi_url.path + self.max_connection_pool_size = max_connection_pool_size + self.max_channel_pool_size = max_channel_pool_size if protocol is None: protocol = builded_asyncapi_url.scheme @@ -273,13 +264,13 @@ def __init__( self.app_id = app_id - self._channel = None self.declarer = None @property def _subscriber_setup_extra(self) -> "AnyDict": return { **super()._subscriber_setup_extra, + "max_consumers": self._max_consumers, "app_id": self.app_id, "virtual_host": self.virtual_host, "declarer": self.declarer, @@ -350,7 +341,7 @@ async def connect( # type: ignore[override] "when mandatory message will be returned" ), ] = Parameter.empty, - ) -> "RobustConnection": + ) -> "ConnectionManager": """Connect broker object to RabbitMQ. To startup subscribers too you should use `broker.start()` after/instead this method. @@ -405,46 +396,37 @@ async def _connect( # type: ignore[override] channel_number: Optional[int], publisher_confirms: bool, on_return_raises: bool, - ) -> "RobustConnection": - connection = cast( - "RobustConnection", - await connect_robust( - url, - timeout=timeout, - ssl_context=ssl_context, - ), - ) - - if self._channel is None: # pragma: no branch - max_consumers = self._max_consumers - channel = self._channel = cast( - "RobustChannel", - await connection.channel( - channel_number=channel_number, - publisher_confirms=publisher_confirms, - on_return_raises=on_return_raises, - ), + ) -> "ConnectionManager": + if self._max_consumers: + c = AsyncAPISubscriber.build_log_context( + None, + RabbitQueue(""), + RabbitExchange(""), ) + self._log(f"Set max consumers to {self._max_consumers}", extra=c) - declarer = self.declarer = RabbitDeclarer(channel) - await declarer.declare_queue(RABBIT_REPLY) + connection_manager = ConnectionManager( + url=url, + timeout=timeout, + ssl_context=ssl_context, + connection_pool_size=self.max_connection_pool_size, + channel_pool_size=self.max_channel_pool_size, + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, + ) + + if self.declarer is None: + self.declarer = RabbitDeclarer(connection_manager) + if self._producer is None: self._producer = AioPikaFastProducer( - declarer=declarer, + declarer=self.declarer, decoder=self._decoder, parser=self._parser, ) - if max_consumers: - c = AsyncAPISubscriber.build_log_context( - None, - RabbitQueue(""), - RabbitExchange(""), - ) - self._log(f"Set max consumers to {max_consumers}", extra=c) - await channel.set_qos(prefetch_count=int(max_consumers)) - - return connection + return connection_manager async def _close( self, @@ -452,18 +434,12 @@ async def _close( exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: - if self._channel is not None: - if not self._channel.is_closed: - await self._channel.close() - - self._channel = None + if self._connection is not None: + await self._connection.close() self.declarer = None self._producer = None - if self._connection is not None: - await self._connection.close() - await super()._close(exc_type, exc_val, exc_tb) async def start(self) -> None: diff --git a/faststream/rabbit/broker/connection.py b/faststream/rabbit/broker/connection.py new file mode 100644 index 0000000000..b332eb8a42 --- /dev/null +++ b/faststream/rabbit/broker/connection.py @@ -0,0 +1,87 @@ +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, AsyncIterator, Optional, cast + +from aio_pika import connect_robust +from aio_pika.pool import Pool + +if TYPE_CHECKING: + from ssl import SSLContext + + from aio_pika import ( + RobustChannel, + RobustConnection, + ) + from aio_pika.abc import TimeoutType + + +class ConnectionManager: + def __init__( + self, + *, + url: str, + timeout: "TimeoutType", + ssl_context: Optional["SSLContext"], + connection_pool_size: Optional[int], + channel_pool_size: Optional[int], + channel_number: Optional[int], + publisher_confirms: bool, + on_return_raises: bool, + ) -> None: + self._connection_pool: "Pool[RobustConnection]" = Pool( + lambda: connect_robust( + url=url, + timeout=timeout, + ssl_context=ssl_context, + ), + max_size=connection_pool_size, + ) + + self._channel_pool: "Pool[RobustChannel]" = Pool( + lambda: self._get_channel( + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, + ), + max_size=channel_pool_size, + ) + + async def get_connection(self) -> "RobustConnection": + return await self._connection_pool._get() + + @asynccontextmanager + async def acquire_connection(self) -> AsyncIterator["RobustConnection"]: + async with self._connection_pool.acquire() as connection: + yield connection + + async def get_channel(self) -> "RobustChannel": + return await self._channel_pool._get() + + @asynccontextmanager + async def acquire_channel(self) -> AsyncIterator["RobustChannel"]: + async with self._channel_pool.acquire() as channel: + yield channel + + async def _get_channel( + self, + channel_number: Optional[int] = None, + publisher_confirms: bool = True, + on_return_raises: bool = False, + ) -> "RobustChannel": + async with self.acquire_connection() as connection: + channel = cast( + "RobustChannel", + await connection.channel( + channel_number=channel_number, + publisher_confirms=publisher_confirms, + on_return_raises=on_return_raises, + ), + ) + + return channel + + async def close(self) -> None: + if not self._channel_pool.is_closed: + await self._channel_pool.close() + + if not self._connection_pool.is_closed: + await self._connection_pool.close() diff --git a/faststream/rabbit/broker/logging.py b/faststream/rabbit/broker/logging.py index 21af975df4..3c2a913477 100644 --- a/faststream/rabbit/broker/logging.py +++ b/faststream/rabbit/broker/logging.py @@ -2,16 +2,17 @@ from inspect import Parameter from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union -from aio_pika import IncomingMessage, RobustConnection +from aio_pika import IncomingMessage from faststream.broker.core.usecase import BrokerUsecase from faststream.log.logging import get_broker_logger +from faststream.rabbit.broker.connection import ConnectionManager if TYPE_CHECKING: from faststream.types import LoggerProto -class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, RobustConnection]): +class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, ConnectionManager]): """A class that extends the LoggingMixin class and adds additional functionality for logging RabbitMQ related information.""" _max_queue_len: int diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py index d0445badfb..e965e45ca7 100644 --- a/faststream/rabbit/fastapi/router.py +++ b/faststream/rabbit/fastapi/router.py @@ -26,10 +26,7 @@ from faststream.broker.utils import default_filter from faststream.rabbit.broker.broker import RabbitBroker as RB from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher -from faststream.rabbit.schemas import ( - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber if TYPE_CHECKING: @@ -414,6 +411,14 @@ def __init__( """ ), ] = Default(generate_unique_id), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: super().__init__( url, @@ -424,6 +429,8 @@ def __init__( client_properties=client_properties, timeout=timeout, max_consumers=max_consumers, + max_connection_pool_size=max_connection_pool_size, + max_channel_pool_size=max_channel_pool_size, app_id=app_id, graceful_timeout=graceful_timeout, decoder=decoder, diff --git a/faststream/rabbit/helpers/declarer.py b/faststream/rabbit/helpers/declarer.py index 57c21a3a78..a5f38f8deb 100644 --- a/faststream/rabbit/helpers/declarer.py +++ b/faststream/rabbit/helpers/declarer.py @@ -1,20 +1,22 @@ -from typing import TYPE_CHECKING, Dict, cast +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Dict, Optional, cast if TYPE_CHECKING: import aio_pika + from faststream.rabbit.broker.connection import ConnectionManager from faststream.rabbit.schemas import RabbitExchange, RabbitQueue class RabbitDeclarer: """An utility class to declare RabbitMQ queues and exchanges.""" - __channel: "aio_pika.RobustChannel" + __connection_manager: "ConnectionManager" __queues: Dict["RabbitQueue", "aio_pika.RobustQueue"] __exchanges: Dict["RabbitExchange", "aio_pika.RobustExchange"] - def __init__(self, channel: "aio_pika.RobustChannel") -> None: - self.__channel = channel + def __init__(self, connection_manager: "ConnectionManager") -> None: + self.__connection_manager = connection_manager self.__queues = {} self.__exchanges = {} @@ -22,58 +24,81 @@ async def declare_queue( self, queue: "RabbitQueue", passive: bool = False, + *, + channel: Optional["aio_pika.RobustChannel"] = None, ) -> "aio_pika.RobustQueue": """Declare a queue.""" - if (q := self.__queues.get(queue)) is None: - self.__queues[queue] = q = cast( - "aio_pika.RobustQueue", - await self.__channel.declare_queue( - name=queue.name, - durable=queue.durable, - exclusive=queue.exclusive, - passive=passive or queue.passive, - auto_delete=queue.auto_delete, - arguments=queue.arguments, - timeout=queue.timeout, - robust=queue.robust, - ), - ) - - return q + if (queue_obj := self.__queues.get(queue)) is None: + async with AsyncExitStack() as stack: + if channel is None: + channel = await stack.enter_async_context( + self.__connection_manager.acquire_channel() + ) + + self.__queues[queue] = queue_obj = cast( + "aio_pika.RobustQueue", + await channel.declare_queue( + name=queue.name, + durable=queue.durable, + exclusive=queue.exclusive, + passive=passive or queue.passive, + auto_delete=queue.auto_delete, + arguments=queue.arguments, + timeout=queue.timeout, + robust=queue.robust, + ), + ) + + return queue_obj # type: ignore[return-value] async def declare_exchange( self, exchange: "RabbitExchange", passive: bool = False, + *, + channel: Optional["aio_pika.RobustChannel"] = None, ) -> "aio_pika.RobustExchange": """Declare an exchange, parent exchanges and bind them each other.""" - if not exchange.name: - return self.__channel.default_exchange - - if (exch := self.__exchanges.get(exchange)) is None: - self.__exchanges[exchange] = exch = cast( - "aio_pika.RobustExchange", - await self.__channel.declare_exchange( - name=exchange.name, - type=exchange.type.value, - durable=exchange.durable, - auto_delete=exchange.auto_delete, - passive=passive or exchange.passive, - arguments=exchange.arguments, - timeout=exchange.timeout, - robust=exchange.robust, - internal=False, # deprecated RMQ option - ), - ) - - if exchange.bind_to is not None: - parent = await self.declare_exchange(exchange.bind_to) - await exch.bind( - exchange=parent, - routing_key=exchange.routing, - arguments=exchange.bind_arguments, - timeout=exchange.timeout, - robust=exchange.robust, + if exch := self.__exchanges.get(exchange): + return exch + + async with AsyncExitStack() as stack: + if channel is None: + channel = await stack.enter_async_context( + self.__connection_manager.acquire_channel() ) - return exch + if not exchange.name: + return channel.default_exchange + + else: + self.__exchanges[exchange] = exch = cast( + "aio_pika.RobustExchange", + await channel.declare_exchange( + name=exchange.name, + type=exchange.type.value, + durable=exchange.durable, + auto_delete=exchange.auto_delete, + passive=passive or exchange.passive, + arguments=exchange.arguments, + timeout=exchange.timeout, + robust=exchange.robust, + internal=False, # deprecated RMQ option + ), + ) + + if exchange.bind_to is not None: + parent = await self.declare_exchange( + exchange.bind_to, + channel=channel, + ) + + await exch.bind( + exchange=parent, + routing_key=exchange.routing, + arguments=exchange.bind_arguments, + timeout=exchange.timeout, + robust=exchange.robust, + ) + + return exch # type: ignore[return-value] diff --git a/faststream/rabbit/publisher/producer.py b/faststream/rabbit/publisher/producer.py index f7a4013bab..4d1a6acd23 100644 --- a/faststream/rabbit/publisher/producer.py +++ b/faststream/rabbit/publisher/producer.py @@ -23,7 +23,7 @@ from types import TracebackType import aiormq - from aio_pika import IncomingMessage, RobustQueue + from aio_pika import IncomingMessage, RobustChannel, RobustQueue from aio_pika.abc import DateType, HeadersType, TimeoutType from anyio.streams.memory import MemoryObjectReceiveStream @@ -88,15 +88,19 @@ async def publish( # type: ignore[override] context: AsyncContextManager[ Optional[MemoryObjectReceiveStream[IncomingMessage]] ] + channel: Optional["RobustChannel"] + if rpc: if reply_to is not None: raise WRONG_PUBLISH_ARGS - context = _RPCCallback( - self._rpc_lock, - await self.declarer.declare_queue(RABBIT_REPLY), - ) + rmq_queue = await self.declarer.declare_queue(RABBIT_REPLY) + channel = cast("RobustChannel", rmq_queue.channel) + context = _RPCCallback(self._rpc_lock, rmq_queue) + reply_to = RABBIT_REPLY.name + else: + channel = None context = fake_context() async with context as response_queue: @@ -108,7 +112,7 @@ async def publish( # type: ignore[override] immediate=immediate, timeout=timeout, persist=persist, - reply_to=reply_to if response_queue is None else RABBIT_REPLY.name, + reply_to=reply_to, headers=headers, content_type=content_type, content_encoding=content_encoding, @@ -120,6 +124,7 @@ async def publish( # type: ignore[override] message_type=message_type, user_id=user_id, app_id=app_id, + channel=channel, ) if response_queue is None: @@ -157,6 +162,7 @@ async def _publish( message_type: Optional[str], user_id: Optional[str], app_id: Optional[str], + channel: Optional["RobustChannel"], ) -> Union["aiormq.abc.ConfirmationFrameType", "SendableMessage"]: """Publish a message to a RabbitMQ exchange.""" message = AioPikaParser.encode_message( @@ -179,6 +185,7 @@ async def _publish( exchange_obj = await self.declarer.declare_exchange( exchange=RabbitExchange.validate(exchange), passive=True, + channel=channel, ) return await exchange_obj.publish( diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 67421df2da..658e60c987 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -100,6 +100,7 @@ def setup( # type: ignore[override] self, *, app_id: Optional[str], + max_consumers: Optional[int], virtual_host: str, declarer: "RabbitDeclarer", # basic args @@ -119,6 +120,7 @@ def setup( # type: ignore[override] self.app_id = app_id self.virtual_host = virtual_host self.declarer = declarer + self.__max_consumers = max_consumers super().setup( logger=logger, @@ -141,6 +143,9 @@ async def start(self) -> None: self._queue_obj = queue = await self.declarer.declare_queue(self.queue) + if self.__max_consumers is not None: + await queue.channel.set_qos(prefetch_count=self.__max_consumers) + if ( self.exchange is not None and not queue.passive # queue just getted from RMQ diff --git a/faststream/rabbit/testing.py b/faststream/rabbit/testing.py index 3d3a274418..2722116fe9 100644 --- a/faststream/rabbit/testing.py +++ b/faststream/rabbit/testing.py @@ -13,11 +13,7 @@ from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - ExchangeType, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.testing.broker import TestBroker, call_handler @@ -35,7 +31,6 @@ class TestRabbitBroker(TestBroker[RabbitBroker]): @classmethod def _patch_test_broker(cls, broker: RabbitBroker) -> None: - broker._channel = AsyncMock() broker.declarer = AsyncMock() super()._patch_test_broker(broker) diff --git a/tests/brokers/rabbit/specific/test_declare.py b/tests/brokers/rabbit/specific/test_declare.py index aed6824f3e..eb16dc648c 100644 --- a/tests/brokers/rabbit/specific/test_declare.py +++ b/tests/brokers/rabbit/specific/test_declare.py @@ -1,59 +1,73 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator +from unittest.mock import AsyncMock + import pytest from faststream.rabbit import RabbitBroker, RabbitExchange, RabbitQueue from faststream.rabbit.helpers.declarer import RabbitDeclarer +class FakeConnectionManage: + def __init__(self) -> None: + self.mock = AsyncMock() + + @asynccontextmanager + async def acquire_channel(self) -> AsyncIterator[AsyncMock]: + yield self.mock + + @pytest.mark.asyncio() -async def test_declare_queue(async_mock, queue: str): - declarer = RabbitDeclarer(async_mock) +async def test_declare_queue(queue: str): + manager = FakeConnectionManage() + declarer = RabbitDeclarer(manager) q1 = await declarer.declare_queue(RabbitQueue(queue)) q2 = await declarer.declare_queue(RabbitQueue(queue)) assert q1 is q2 - async_mock.declare_queue.assert_awaited_once() + manager.mock.declare_queue.assert_awaited_once() @pytest.mark.asyncio() async def test_declare_exchange( - async_mock, queue: str, ): - declarer = RabbitDeclarer(async_mock) + manager = FakeConnectionManage() + declarer = RabbitDeclarer(manager) ex1 = await declarer.declare_exchange(RabbitExchange(queue)) ex2 = await declarer.declare_exchange(RabbitExchange(queue)) assert ex1 is ex2 - async_mock.declare_exchange.assert_awaited_once() + manager.mock.declare_exchange.assert_awaited_once() @pytest.mark.asyncio() async def test_declare_nested_exchange_cash_nested( - async_mock, queue: str, ): - declarer = RabbitDeclarer(async_mock) + manager = FakeConnectionManage() + declarer = RabbitDeclarer(manager) exchange = RabbitExchange(queue) await declarer.declare_exchange(RabbitExchange(queue + "1", bind_to=exchange)) - assert async_mock.declare_exchange.await_count == 2 + assert manager.mock.declare_exchange.await_count == 2 await declarer.declare_exchange(exchange) - assert async_mock.declare_exchange.await_count == 2 + assert manager.mock.declare_exchange.await_count == 2 @pytest.mark.asyncio() async def test_publisher_declare( - async_mock, queue: str, ): - declarer = RabbitDeclarer(async_mock) + manager = FakeConnectionManage() + declarer = RabbitDeclarer(manager) broker = RabbitBroker() - broker._connection = async_mock + broker._connection = manager.mock broker.declarer = declarer @broker.publisher(queue, queue) @@ -61,5 +75,5 @@ async def f(): ... await broker.start() - assert not async_mock.declare_queue.await_count - async_mock.declare_exchange.assert_awaited_once() + assert not manager.mock.declare_queue.await_count + manager.mock.declare_exchange.assert_awaited_once() diff --git a/tests/brokers/rabbit/specific/test_init.py b/tests/brokers/rabbit/specific/test_init.py deleted file mode 100644 index e87b71b466..0000000000 --- a/tests/brokers/rabbit/specific/test_init.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest - -from faststream.rabbit import RabbitBroker - - -@pytest.mark.asyncio() -@pytest.mark.rabbit() -async def test_set_max(): - broker = RabbitBroker(logger=None, max_consumers=10) - await broker.start() - assert broker._channel._prefetch_count == 10 - await broker.close() diff --git a/tests/brokers/rabbit/specific/test_qos.py b/tests/brokers/rabbit/specific/test_qos.py new file mode 100644 index 0000000000..890a2fa30a --- /dev/null +++ b/tests/brokers/rabbit/specific/test_qos.py @@ -0,0 +1,20 @@ +import pytest + +from faststream.rabbit import RabbitBroker, RabbitQueue + + +@pytest.mark.asyncio() +@pytest.mark.rabbit() +async def test_set_max(): + queue = RabbitQueue("test") + + broker = RabbitBroker(logger=None, max_consumers=10) + + @broker.subscriber(queue) + async def handler(): ... + + async with broker: + await broker.start() + + queue = await broker.declare_queue(queue) + assert queue.channel._prefetch_count == 10