"""Public API of the FastStream SQS integration.

Re-exports the broker annotations, router, queue schemas and testing
utilities so users can import everything from ``faststream.sqs`` directly.
"""

from faststream.broker.test import TestApp
from faststream.sqs.annotations import SQSBroker, SQSMessage
from faststream.sqs.router import SQSRouter
from faststream.sqs.shared.router import SQSRoute
from faststream.sqs.shared.schemas import (
    FifoQueue,
    RedriveAllowPolicy,
    RedrivePolicy,
    SQSQueue,
)
from faststream.sqs.test import TestSQSBroker

# Alphabetized (the original listed "SQSRouter" before "SQSRoute").
__all__ = (
    "FifoQueue",
    "RedriveAllowPolicy",
    "RedrivePolicy",
    "SQSBroker",
    "SQSMessage",
    "SQSQueue",
    "SQSRoute",
    "SQSRouter",
    "TestApp",
    "TestSQSBroker",
)
class SQSMessage(StreamMessage[AnyDict]):
    """Incoming SQS message wrapper.

    Tracks whether the message has already been settled (acked, nacked or
    rejected) via the ``commited`` flag.
    """

    # NOTE(review): "commited" is misspelled but kept as-is — it is a public
    # attribute; consider deprecating in favor of "committed" later.

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Freshly received messages are not settled yet.
        self.commited = False

    async def ack(self, **kwargs: Any) -> None:
        # TODO: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html
        self._settle()

    async def nack(self, **kwargs: Any) -> None:
        # TODO
        self._settle()

    async def reject(self, **kwargs: Any) -> None:
        # TODO
        self._settle()

    def _settle(self) -> None:
        """Mark the message as settled (shared by ack/nack/reject)."""
        self.commited = True
class SQSLoggingMixin(LoggingMixin):
    """Logging mixin that adds the SQS queue name to broker log records."""

    # Current column width for the queue name in the default log format.
    _max_queue_len: int

    def __init__(
        self,
        *args: Any,
        logger: Optional[logging.Logger] = access_logger,
        log_level: int = logging.INFO,
        log_fmt: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            *args,
            logger=logger,
            log_level=log_level,
            log_fmt=log_fmt,
            **kwargs,
        )
        # Minimal width; grown by `_setup_log_context` as queues register.
        self._max_queue_len = 4

    @override
    def _get_log_context(  # type: ignore[override]
        self,
        message: Optional[StreamMessage[Any]],
        queue: str = "",
    ) -> AnyDict:
        """Extend the base log context with the queue name."""
        extra = {"queue": queue}
        # Base-class keys intentionally win on collision, as in the original
        # `{"queue": ..., **super()...}` form.
        extra.update(super()._get_log_context(message))
        return extra

    @property
    def fmt(self) -> str:
        """Explicit format if configured, else the default SQS-aware one."""
        if self._fmt:
            return self._fmt
        width = self._max_queue_len
        return (
            "%(asctime)s %(levelname)s - "
            f"%(queue)-{width}s | "
            "%(message_id)-10s "
            "- %(message)s"
        )

    def _setup_log_context(
        self,
        queue: Optional[str] = None,
    ) -> None:
        """Grow the queue-name column so *queue* fits."""
        if queue is None:
            return
        self._max_queue_len = max(self._max_queue_len, len(queue))
from typing import Any, Dict, Optional, Sequence

from pydantic import BaseModel, Field, PositiveInt
from typing_extensions import Literal

from faststream._compat import PYDANTIC_V2
from faststream.broker.schemas import NameRequired


class RedrivePolicy(BaseModel):
    """SQS Queue RedrivePolicy attribute details."""

    dead_letter_target: str = Field(
        default="",
        alias="deadLetterTargetArn",
    )
    # FIX: alias was "deadLetterTargetArn", duplicating the field above and
    # producing an invalid RedrivePolicy document (SQS expects
    # "maxReceiveCount").
    max_receive_count: PositiveInt = Field(
        default=10,
        alias="maxReceiveCount",
    )


class RedriveAllowPolicy(BaseModel):
    """SQS Queue RedriveAllowPolicy attribute details."""

    redrive_permission: Literal["allowAll", "denyAll", "byQueue"] = Field(
        default="allowAll",
        alias="redrivePermission",
    )
    # SQS allows at most 10 source-queue ARNs per policy.
    # NOTE(review): `max_length` constrains item count on pydantic v2 only —
    # confirm the v1 (`max_items`) path if v1 support is required.
    source_queue_arns: Sequence[str] = Field(
        default_factory=tuple,
        alias="sourceQueueArns",
        max_length=10,
    )


class SQSQueue(NameRequired):
    """SQS Basic Queue initialization attributes.

    Aliases match the AWS ``CreateQueue`` attribute names so that
    ``model_to_dict(..., by_alias=True)`` yields a valid attribute map.
    """

    fifo: bool = Field(
        default=False,
        alias="FifoQueue",
    )

    # FIX: the alias was commented out, so by-alias serialization emitted the
    # invalid attribute name "delay_seconds" instead of "DelaySeconds".
    delay_seconds: int = Field(
        default=0,
        alias="DelaySeconds",
        ge=0,
        le=900,
    )
    max_message_size: int = Field(
        default=262_144,
        alias="MaximumMessageSize",
        ge=1024,
        le=262_144,
    )
    retention_period_sec: int = Field(
        345_600,
        alias="MessageRetentionPeriod",
        ge=60,
        le=1_209_600,
    )
    receive_wait_time_sec: int = Field(
        default=0,
        alias="ReceiveMessageWaitTimeSeconds",
        ge=0,
        le=20,
    )
    visibility_timeout_sec: int = Field(
        default=30,
        alias="VisibilityTimeout",
        ge=0,
        le=43_200,
    )
    redrive_policy: RedrivePolicy = Field(
        default_factory=RedrivePolicy,
        alias="RedrivePolicy",
    )
    # FIX: default_factory was RedrivePolicy — the wrong model.
    redrive_allow_policy: RedriveAllowPolicy = Field(
        default_factory=RedriveAllowPolicy,
        alias="RedriveAllowPolicy",
    )

    kms_master_key_id: str = Field(default="", alias="KmsMasterKeyId")
    kms_data_key_reuse_period_sec: int = Field(
        default=300,
        alias="KmsDataKeyReusePeriodSeconds",
        ge=60,
        le=86_400,
    )
    sse_enabled: bool = Field(
        default=False,
        alias="SqsManagedSseEnabled",
    )
    # Tags are passed to CreateQueue separately, not as an attribute.
    tags: Dict[str, str] = Field(
        default_factory=dict,
    )

    def __init__(
        self,
        name: str,
        fifo: bool = False,
        delay_seconds: int = 0,
        max_message_size: int = 262_144,
        # FIX: default was 0, contradicting the field default (30, the SQS
        # service default) and forcing VisibilityTimeout=0 — instantly
        # re-visible messages — on every queue built via the constructor.
        visibility_timeout_sec: int = 30,
        receive_wait_time_sec: int = 0,
        retention_period_sec: int = 345_600,
        redrive_policy: Optional[RedrivePolicy] = None,
        redrive_allow_policy: Optional[RedriveAllowPolicy] = None,
        kms_master_key_id: str = "",
        kms_data_key_reuse_period_sec: int = 300,
        sse_enabled: bool = False,
        tags: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ):
        """Positional-``name`` convenience constructor over the field set."""
        super().__init__(
            name=name,
            fifo=fifo,
            visibility_timeout_sec=visibility_timeout_sec,
            receive_wait_time_sec=receive_wait_time_sec,
            retention_period_sec=retention_period_sec,
            max_message_size=max_message_size,
            delay_seconds=delay_seconds,
            redrive_policy=redrive_policy or RedrivePolicy(),
            redrive_allow_policy=redrive_allow_policy or RedriveAllowPolicy(),
            kms_master_key_id=kms_master_key_id,
            kms_data_key_reuse_period_sec=kms_data_key_reuse_period_sec,
            sse_enabled=sse_enabled,
            tags=tags or {},
            **kwargs,
        )

    if PYDANTIC_V2:
        model_config = {"arbitrary_types_allowed": True}
    else:

        class Config:
            arbitrary_types_allowed = True


class FifoQueue(SQSQueue):
    """SQS FIFO Queue initialization attributes."""

    fifo: bool = Field(
        default=True,
        alias="FifoQueue",
    )
    content_based_deduplication: bool = Field(
        default=True,
        alias="ContentBasedDeduplication",
    )
    deduplication_scope: Optional[Literal["messageGroup", "queue"]] = Field(
        default=None,
        alias="DeduplicationScope",
    )

    # TODO: pydantic validation and test
    # allow perMessageGroup only for messageGroup deduplication_scope
    throughput_limit: Optional[Literal["perMessageGroup", "perQueue"]] = Field(
        default=None,
        alias="FifoThroughputLimit",
    )

    def __init__(
        self,
        name: str,
        fifo: bool = True,
        delay_seconds: int = 0,
        max_message_size: int = 262_144,
        # FIX: aligned with the field default (30) — see SQSQueue.__init__.
        visibility_timeout_sec: int = 30,
        receive_wait_time_sec: int = 0,
        retention_period_sec: int = 345_600,
        content_based_deduplication: bool = True,
        deduplication_scope: Optional[Literal["messageGroup", "queue"]] = None,
        throughput_limit: Optional[Literal["perMessageGroup", "perQueue"]] = None,
        redrive_policy: Optional[RedrivePolicy] = None,
        redrive_allow_policy: Optional[RedriveAllowPolicy] = None,
        # FIX: accept the KMS key id here too, for parity with SQSQueue
        # (backward-compatible keyword addition).
        kms_master_key_id: str = "",
        kms_data_key_reuse_period_sec: int = 300,
        sse_enabled: bool = False,
        tags: Optional[Dict[str, str]] = None,
    ):
        """FIFO variant; adds deduplication and throughput controls."""
        super().__init__(
            name=name,
            fifo=fifo,
            visibility_timeout_sec=visibility_timeout_sec,
            receive_wait_time_sec=receive_wait_time_sec,
            content_based_deduplication=content_based_deduplication,
            retention_period_sec=retention_period_sec,
            max_message_size=max_message_size,
            delay_seconds=delay_seconds,
            redrive_policy=redrive_policy or RedrivePolicy(),
            redrive_allow_policy=redrive_allow_policy or RedriveAllowPolicy(),
            kms_master_key_id=kms_master_key_id,
            kms_data_key_reuse_period_sec=kms_data_key_reuse_period_sec,
            sse_enabled=sse_enabled,
            deduplication_scope=deduplication_scope,
            throughput_limit=throughput_limit,
            tags=tags or {},
        )
"""Annotated dependency-injection aliases for the SQS integration.

Each alias resolves the named object from the FastStream context when used
as a handler parameter annotation.
"""

from aiobotocore.client import AioBaseClient

from faststream._compat import Annotated
from faststream.annotations import ContextRepo, Logger, NoCast
from faststream.sqs.broker import SQSBroker as SB  # NOQA
from faststream.sqs.message import SQSMessage as SM  # NOQA
from faststream.sqs.producer import SQSFastProducer
from faststream.utils.context import Context

__all__ = (
    "Logger",
    "ContextRepo",
    "NoCast",
    "SQSBroker",
    "SQSMessage",
    "SQSProducer",
    "client",
    "queue_url",
)

# The running broker instance.
SQSBroker = Annotated[SB, Context("broker")]
# The message currently being processed.
SQSMessage = Annotated[SM, Context("message")]
# The broker's low-level producer.
SQSProducer = Annotated[SQSFastProducer, Context("broker._producer")]
# Raw aiobotocore SQS client, set globally by `SQSBroker._connect`.
client = Annotated[AioBaseClient, Context("client")]
# URL of the queue the current handler consumes from
# (scoped per consume loop by `SQSBroker._consume`).
queue_url = Annotated[str, Context("queue_url")]
a/faststream/sqs/broker.py b/faststream/sqs/broker.py index 8b5a0a2925..ed1c06324c 100644 --- a/faststream/sqs/broker.py +++ b/faststream/sqs/broker.py @@ -1,3 +1,95 @@ -class SQSBroker: - # TODO - pass +from types import TracebackType +from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Type, Union + +from aiobotocore.client import AioBaseClient +from fast_depends.dependencies import Depends + +from faststream import BaseMiddleware +from faststream.broker.core.asyncronous import BrokerAsyncUsecase, default_filter +from faststream.broker.message import StreamMessage +from faststream.broker.publisher import BasePublisher +from faststream.broker.push_back_watcher import BaseWatcher +from faststream.broker.types import ( + CustomDecoder, + CustomParser, + Filter, + MsgType, + P_HandlerParams, + T_HandlerReturn, + WrappedReturn, +) +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.sqs.asyncapi import Handler, Publisher +from faststream.sqs.producer import SQSFastProducer +from faststream.sqs.shared.logging import SQSLoggingMixin +from faststream.types import AnyDict, SendableMessage + + +class SQSBroker( + SQSLoggingMixin, + BrokerAsyncUsecase[AnyDict, AioBaseClient], +): + handlers: Dict[str, Handler] # type: ignore[assignment] + _publishers: Dict[str, Publisher] # type: ignore[assignment] + _producer: Optional[SQSFastProducer] + + async def start(self) -> None: + pass + + async def _connect(self, **kwargs: Any) -> AioBaseClient: + pass + + async def _close( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exec_tb: Optional[TracebackType] = None, + ) -> None: + pass + + def _process_message( + self, + func: Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]], + watcher: BaseWatcher, + ) -> Callable[[StreamMessage[MsgType]], Awaitable[WrappedReturn[T_HandlerReturn]],]: + pass + + async def publish( + self, + message: SendableMessage, + *args: Any, + reply_to: 
str = "", + rpc: bool = False, + rpc_timeout: Optional[float] = None, + raise_timeout: bool = False, + **kwargs: Any, + ) -> Optional[SendableMessage]: + pass + + def subscriber( + self, + *broker_args: Any, + retry: Union[bool, int] = False, + dependencies: Sequence[Depends] = (), + decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, + parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, + middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None, + filter: Filter[StreamMessage[MsgType]] = default_filter, + _raw: bool = False, + _get_dependant: Optional[Any] = None, + **broker_kwargs: Any, + ) -> Callable[ + [ + Union[ + Callable[P_HandlerParams, T_HandlerReturn], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ] + ], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ]: + pass + + def publisher( + self, key: Any, publisher: BasePublisher[MsgType] + ) -> BasePublisher[MsgType]: + pass diff --git a/faststream/sqs/handler.py b/faststream/sqs/handler.py index a56f2aeb59..148fe33f24 100644 --- a/faststream/sqs/handler.py +++ b/faststream/sqs/handler.py @@ -1,10 +1,69 @@ +import asyncio +import logging +from typing import Any, NoReturn, Optional + +import anyio from aiobotocore.client import AioBaseClient +from typing_extensions import TypeAlias from faststream.broker.handler import AsyncHandler +from faststream.sqs.shared.schemas import SQSQueue from faststream.types import AnyDict +from faststream.utils.context import context + +QueueUrl: TypeAlias = str class LogicSQSHandler(AsyncHandler[AnyDict]): + queue: SQSQueue + consumer_params: AnyDict + task: Optional["asyncio.Task[Any]"] = None + + async def _consume(self, queue_url: str) -> NoReturn: + c = self._get_log_context(None, self.queue.name) + + connected = True + with context.scope("queue_url", queue_url): + while True: + try: + if connected is False: + await self.create_queue(self.queue) + + r = await 
self._connection.receive_message( + QueueUrl=queue_url, + **self.consumer_params, + ) + + except Exception as e: + if connected is True: + self._log(e, logging.WARNING, c, exc_info=e) + self._queues.pop(self.queue.name) + connected = False + + await anyio.sleep(5) + + else: + if connected is False: + self._log("Connection established", logging.INFO, c) + connected = True + + messages = r.get("Messages", []) + for msg in messages: + try: + await self.callback(msg, True) + except Exception: + has_trash_messages = True + else: + has_trash_messages = False + + if has_trash_messages is True: + await anyio.sleep( + self.consumer_params.get("WaitTimeSeconds", 1.0) + ) + async def start(self, client: AioBaseClient) -> None: - # TODO check "start" method on broker + url = await self.create_queue(self.queue) + self.task = asyncio.create_task(self._consume(url)) + + async def close(self) -> None: pass From 69e23da508342984112750d5ba935cf543d0fd0e Mon Sep 17 00:00:00 2001 From: Carlos Bellino Date: Mon, 13 Nov 2023 16:18:27 +0100 Subject: [PATCH 3/3] WIP: Broker and Handler template added. 
from typing import Dict

from faststream._compat import model_to_dict
from faststream.asyncapi.schema import (
    Channel,
    ChannelBinding,
    CorrelationId,
    Message,
    Operation,
)
from faststream.asyncapi.schema.bindings import sqs
from faststream.asyncapi.utils import resolve_payloads
from faststream.sqs.handler import LogicSQSHandler


class Handler(LogicSQSHandler):
    """SQS subscriber handler enriched with an AsyncAPI description."""

    def schema(self) -> Dict[str, Channel]:
        """Build the AsyncAPI channel document for this subscriber.

        Returns a single-entry mapping of channel name to its ``Channel``
        description, including the message payload schemas and the SQS
        queue binding.
        """
        payloads = self.get_payloads()
        # Prefer an explicit title; fall back to "<queue>:<handler>".
        handler_name = self._title or f"{self.queue.name}:{self.call_name}"

        return {
            handler_name: Channel(
                description=self.description,
                subscribe=Operation(
                    message=Message(
                        title=f"{handler_name}:Message",
                        correlationId=CorrelationId(
                            location="$message.header#/correlation_id"
                        ),
                        payload=resolve_payloads(payloads),
                    ),
                ),
                bindings=ChannelBinding(
                    # Only the queue identity (name/fifo) is exposed in the
                    # binding, not the full attribute set.
                    sqs=sqs.ChannelBinding(
                        queue=model_to_dict(self.queue, include={"name", "fifo"}),
                    )
                ),
            ),
        }


class Publisher:
    # TODO
    pass
class SQSBroker(
    SQSLoggingMixin,
    BrokerAsyncUsecase[AnyDict, AioBaseClient],
):
    """FastStream broker over AWS SQS (via aiobotocore).

    Creates/looks up queues on start and runs one long-polling consume task
    per registered handler.
    """

    handlers: Dict[str, Handler]  # type: ignore[assignment]
    _publishers: Dict[str, Publisher]  # type: ignore[assignment]
    _producer: Optional[SQSFastProducer]

    def __init__(
        self,
        url: str = "http://localhost:9324/",
        *,
        log_fmt: Optional[str] = None,
        response_queue: str = "",
        protocol: str = "sqs",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            url,
            log_fmt=log_fmt,
            url_=url,
            protocol=protocol,
            **kwargs,
        )
        # queue name -> queue URL, filled lazily by `create_queue`.
        self._queues: Dict[str, QueueUrl] = {}
        self.response_queue = response_queue
        # Pending RPC reply waiters; cancelled on close.
        self.response_callbacks: Dict[str, Any] = {}

    async def _connect(self, *, url: str, **kwargs: Any) -> AioBaseClient:
        """Open the aiobotocore SQS client and publish it to the context."""
        session = get_session()
        # NOTE(review): `_create_client` is a private aiobotocore API —
        # consider `session.create_client(...)` as an async context manager.
        client: AioBaseClient = await session._create_client(
            service_name="sqs", endpoint_url=url, **kwargs
        )
        context.set_global("client", client)
        await client.__aenter__()
        return client

    async def _close(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exec_tb: Optional[TracebackType] = None,
    ) -> None:
        """Cancel consume tasks and release the client connection."""
        # FIX: was `super().close(...)` — the public `close()` re-enters this
        # overridden `_close`, recursing forever. Chain to the base `_close`.
        await super()._close(exc_type, exc_val, exec_tb)

        for f in self.response_callbacks.values():
            f.cancel()
        self.response_callbacks = {}

        # FIX: `self.handlers` is a Dict — iterate the handler objects,
        # not the queue-name keys.
        for handler in self.handlers.values():
            if handler.task is not None:
                handler.task.cancel()
                handler.task = None

        if self._connection is not None:
            await self._connection.__aexit__(None, None, None)
            self._connection = None

    def _process_message(
        self,
        func: Callable[[StreamMessage[AnyDict]], Awaitable[T_HandlerReturn]],
        watcher: BaseWatcher,
    ) -> Callable[[StreamMessage[AnyDict]], Awaitable[WrappedReturn[T_HandlerReturn]],]:
        """Wrap a handler with watcher-managed ack and an RPC back-channel."""

        @wraps(func)
        async def process_wrapper(
            message: StreamMessage[AnyDict],
        ) -> WrappedReturn[T_HandlerReturn]:
            async with WatcherContext(watcher, message):
                r = await self._execute_handler(func, message)

                pub_response: Optional[AsyncPublisherProtocol]
                if message.reply_to:
                    # NOTE(review): `publish` declares no `subject` kwarg —
                    # confirm the parameter name once `publish` is implemented.
                    pub_response = FakePublisher(
                        partial(self.publish, subject=message.reply_to)
                    )
                else:
                    pub_response = None

                return r, pub_response

        return process_wrapper

    async def create_queue(self, queue: SQSQueue) -> QueueUrl:
        """Create (or fetch) the queue and cache its URL by name."""
        url = self._queues.get(queue.name)

        if url is None:  # pragma: no branch
            url = (
                await self._connection.create_queue(
                    QueueName=queue.name,
                    # SQS expects every attribute value as a string.
                    Attributes={
                        i: str(j)
                        for i, j in model_to_dict(
                            queue,
                            exclude={"name", "tags"},
                            by_alias=True,
                            exclude_defaults=True,
                            exclude_unset=True,
                        ).items()
                    },
                    tags=queue.tags,
                )
            ).get("QueueUrl", "")
            self._queues[queue.name] = url

        return url

    async def start(self) -> None:
        """Connect, then spawn one consume task per registered handler."""
        context.set_local(
            "log_context",
            self._get_log_context(None, ""),
        )

        await super().start()

        for handler in self.handlers.values():  # pragma: no branch
            c = self._get_log_context(None, handler.queue.name)
            # FIX: `call_name` is already a string (see asyncapi.Handler);
            # `.__name__` raised AttributeError.
            self._log(f"`{handler.call_name}` waiting for messages", extra=c)

            url = await self.create_queue(handler.queue)
            handler.task = asyncio.create_task(self._consume(url, handler))

    def subscriber(
        self,
        *broker_args: Any,
        retry: Union[bool, int] = False,
        dependencies: Sequence[Depends] = (),
        decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None,
        parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None,
        middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None,
        filter: Filter[StreamMessage[MsgType]] = default_filter,
        _raw: bool = False,
        _get_dependant: Optional[Any] = None,
        **broker_kwargs: Any,
    ) -> Callable[
        [
            Union[
                Callable[P_HandlerParams, T_HandlerReturn],
                HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn],
            ]
        ],
        HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn],
    ]:
        """Register a queue subscriber (not yet implemented)."""
        # TODO
        pass

    async def publish(
        self,
        message: SendableMessage,
        *args: Any,
        reply_to: str = "",
        rpc: bool = False,
        rpc_timeout: Optional[float] = None,
        raise_timeout: bool = False,
        **kwargs: Any,
    ) -> Optional[SendableMessage]:
        """Publish a message to a queue (not yet implemented)."""
        # TODO
        pass

    def publisher(
        self, key: Any, publisher: BasePublisher[MsgType]
    ) -> BasePublisher[MsgType]:
        """Register a publisher object (not yet implemented)."""
        # TODO
        pass

    async def _consume(self, queue_url: str, handler: Handler) -> NoReturn:
        """Long-poll *queue_url* forever, dispatching batches to *handler*.

        Reconnection policy: on any receive error, drop the cached queue URL,
        re-create the queue and retry every 5 seconds.
        """
        c = self._get_log_context(None, handler.queue.name)

        connected = True
        with context.scope("queue_url", queue_url):
            while True:
                try:
                    if connected is False:
                        await self.create_queue(handler.queue)

                    r = await self._connection.receive_message(
                        QueueUrl=queue_url,
                        **handler.consumer_params,
                    )

                except Exception as e:
                    if connected is True:
                        self._log(e, logging.WARNING, c, exc_info=e)
                        self._queues.pop(handler.queue.name)
                        connected = False

                    await anyio.sleep(5)

                else:
                    if connected is False:
                        self._log("Connection established", logging.INFO, c)
                        connected = True

                    # FIX: previously unbound when a poll returned no
                    # messages (NameError below), and the per-message reset
                    # made it reflect only the *last* message; now True if
                    # ANY message in the batch failed.
                    has_trash_messages = False
                    for msg in r.get("Messages", []):
                        try:
                            await handler.callback(msg, True)
                        except Exception:
                            has_trash_messages = True

                    if has_trash_messages is True:
                        # Back off before re-polling a failing batch source.
                        await anyio.sleep(
                            handler.consumer_params.get("WaitTimeSeconds", 1.0)
                        )
import asyncio
from typing import Any, Optional

from typing_extensions import TypeAlias

from faststream.broker.handler import AsyncHandler
from faststream.sqs.shared.schemas import SQSQueue
from faststream.types import AnyDict

# Readability alias: SQS queue URLs are plain strings.
QueueUrl: TypeAlias = str


class LogicSQSHandler(AsyncHandler[AnyDict]):
    """SQS subscription handler.

    Message consumption is driven by ``SQSBroker._consume``; this class
    carries the per-subscription state the broker reads.
    """

    # Queue this handler is subscribed to.
    queue: SQSQueue
    # Extra kwargs forwarded to `receive_message`
    # (e.g. WaitTimeSeconds — see `SQSBroker._consume`).
    consumer_params: AnyDict
    # Polling task created by `SQSBroker.start`, cancelled by the broker.
    task: Optional["asyncio.Task[Any]"] = None

    async def start(self) -> None:
        # TODO
        pass

    async def close(self) -> None:
        # TODO
        pass