diff --git a/faststream/sqs/__init__.py b/faststream/sqs/__init__.py new file mode 100644 index 0000000000..16ca43cc96 --- /dev/null +++ b/faststream/sqs/__init__.py @@ -0,0 +1,24 @@ +from faststream.broker.test import TestApp +from faststream.sqs.annotations import SQSBroker, SQSMessage +from faststream.sqs.router import SQSRouter +from faststream.sqs.shared.router import SQSRoute +from faststream.sqs.shared.schemas import ( + FifoQueue, + RedriveAllowPolicy, + RedrivePolicy, + SQSQueue, +) +from faststream.sqs.test import TestSQSBroker + +__all__ = ( + "FifoQueue", + "RedriveAllowPolicy", + "RedrivePolicy", + "SQSBroker", + "SQSMessage", + "SQSQueue", + "SQSRouter", + "SQSRoute", + "TestApp", + "TestSQSBroker", +) diff --git a/faststream/sqs/annotations.py b/faststream/sqs/annotations.py new file mode 100644 index 0000000000..c6c754b5a2 --- /dev/null +++ b/faststream/sqs/annotations.py @@ -0,0 +1,25 @@ +from aiobotocore.client import AioBaseClient + +from faststream._compat import Annotated +from faststream.annotations import ContextRepo, Logger, NoCast +from faststream.sqs.broker import SQSBroker as SB # NOQA +from faststream.sqs.message import SQSMessage as SM # NOQA +from faststream.sqs.producer import SQSFastProducer +from faststream.utils.context import Context + +__all__ = ( + "Logger", + "ContextRepo", + "NoCast", + "SQSBroker", + "SQSMessage", + "SQSProducer", + "client", + "queue_url", +) + +SQSBroker = Annotated[SB, Context("broker")] +SQSMessage = Annotated[SM, Context("message")] +SQSProducer = Annotated[SQSFastProducer, Context("broker._producer")] +client = Annotated[AioBaseClient, Context("client")] +queue_url = Annotated[str, Context("queue_url")] diff --git a/faststream/sqs/asyncapi.py b/faststream/sqs/asyncapi.py new file mode 100644 index 0000000000..8b9fb0151e --- /dev/null +++ b/faststream/sqs/asyncapi.py @@ -0,0 +1,44 @@ +from typing import Dict + +from faststream._compat import model_to_dict +from faststream.asyncapi.schema import ( + 
Channel, + ChannelBinding, + CorrelationId, + Message, + Operation, +) +from faststream.asyncapi.schema.bindings import sqs +from faststream.asyncapi.utils import resolve_payloads +from faststream.sqs.handler import LogicSQSHandler + + +class Handler(LogicSQSHandler): + def schema(self) -> Dict[str, Channel]: + payloads = self.get_payloads() + handler_name = self._title or f"{self.queue.name}:{self.call_name}" + + return { + handler_name: Channel( + description=self.description, + subscribe=Operation( + message=Message( + title=f"{handler_name}:Message", + correlationId=CorrelationId( + location="$message.header#/correlation_id" + ), + payload=resolve_payloads(payloads), + ), + ), + bindings=ChannelBinding( + sqs=sqs.ChannelBinding( + queue=model_to_dict(self.queue, include={"name", "fifo"}), + ) + ), + ), + } + + +class Publisher: + # TODO + pass diff --git a/faststream/sqs/broker.py b/faststream/sqs/broker.py new file mode 100644 index 0000000000..67317a5509 --- /dev/null +++ b/faststream/sqs/broker.py @@ -0,0 +1,248 @@ +import asyncio +import logging +from functools import partial, wraps +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Dict, + NoReturn, + Optional, + Sequence, + Type, + Union, +) + +import anyio +from aiobotocore.client import AioBaseClient +from aiobotocore.session import get_session +from fast_depends.dependencies import Depends + +from faststream import BaseMiddleware +from faststream._compat import model_to_dict +from faststream.broker.core.asyncronous import BrokerAsyncUsecase, default_filter +from faststream.broker.message import StreamMessage +from faststream.broker.publisher import BasePublisher +from faststream.broker.push_back_watcher import BaseWatcher, WatcherContext +from faststream.broker.types import ( + AsyncPublisherProtocol, + CustomDecoder, + CustomParser, + Filter, + MsgType, + P_HandlerParams, + T_HandlerReturn, + WrappedReturn, +) +from faststream.broker.wrapper import FakePublisher, 
HandlerCallWrapper +from faststream.sqs.shared.schemas import SQSQueue +from faststream.sqs.asyncapi import Handler, Publisher +from faststream.sqs.handler import QueueUrl +from faststream.sqs.producer import SQSFastProducer +from faststream.sqs.shared.logging import SQSLoggingMixin +from faststream.types import AnyDict, SendableMessage +from faststream.utils import context + + +class SQSBroker( + SQSLoggingMixin, + BrokerAsyncUsecase[AnyDict, AioBaseClient], +): + handlers: Dict[str, Handler] # type: ignore[assignment] + _publishers: Dict[str, Publisher] # type: ignore[assignment] + _producer: Optional[SQSFastProducer] + + def __init__( + self, + url: str = "http://localhost:9324/", + *, + log_fmt: Optional[str] = None, + response_queue: str = "", + protocol: str = "sqs", + **kwargs: Any, + ) -> None: + super().__init__( + url, + log_fmt=log_fmt, + url_=url, + protocol=protocol, + **kwargs, + ) + self._queues = {} + self.response_queue = response_queue + self.response_callbacks = {} + + async def _connect(self, *, url: str, **kwargs: Any) -> AioBaseClient: + session = get_session() + client: AioBaseClient = await session._create_client( + service_name="sqs", endpoint_url=url, **kwargs + ) + context.set_global("client", client) + await client.__aenter__() + return client + + async def _close( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exec_tb: Optional[TracebackType] = None, + ) -> None: + await super()._close(exc_type, exc_val, exec_tb) + for f in self.response_callbacks.values(): + f.cancel() + self.response_callbacks = {} + + for h in self.handlers.values(): + if h.task is not None: + h.task.cancel() + h.task = None + + if self._connection is not None: + await self._connection.__aexit__(None, None, None) + self._connection = None + + def _process_message( + self, + func: Callable[[StreamMessage[AnyDict]], Awaitable[T_HandlerReturn]], + watcher: BaseWatcher, + ) -> Callable[[StreamMessage[AnyDict]],
Awaitable[WrappedReturn[T_HandlerReturn]],]: + @wraps(func) + async def process_wrapper( + message: StreamMessage[AnyDict], + ) -> WrappedReturn[T_HandlerReturn]: + async with WatcherContext(watcher, message): + r = await self._execute_handler(func, message) + pub_response: Optional[AsyncPublisherProtocol] + if message.reply_to: + pub_response = FakePublisher( + partial(self.publish, subject=message.reply_to) + ) + else: + pub_response = None + + return r, pub_response + + return process_wrapper + + async def create_queue(self, queue: SQSQueue) -> QueueUrl: + url = self._queues.get(queue.name) + if url is None: # pragma: no branch + url = ( + await self._connection.create_queue( + QueueName=queue.name, + Attributes={ + i: str(j) + for i, j in model_to_dict( + queue, + exclude={"name", "tags"}, + by_alias=True, + exclude_defaults=True, + exclude_unset=True, + ).items() + }, + tags=queue.tags, + ) + ).get("QueueUrl", "") + self._queues[queue.name] = url + return url + + async def start(self) -> None: + context.set_local( + "log_context", + self._get_log_context(None, ""), + ) + + await super().start() + + for handler in self.handlers.values(): # pragma: no branch + c = self._get_log_context(None, handler.queue.name) + self._log(f"`{handler.call_name}` waiting for messages", extra=c) + + url = await self.create_queue(handler.queue) + handler.task = asyncio.create_task(self._consume(url, handler)) + + def subscriber( + self, + *broker_args: Any, + retry: Union[bool, int] = False, + dependencies: Sequence[Depends] = (), + decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, + parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, + middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None, + filter: Filter[StreamMessage[MsgType]] = default_filter, + _raw: bool = False, + _get_dependant: Optional[Any] = None, + **broker_kwargs: Any, + ) -> Callable[ + [ + Union[ + Callable[P_HandlerParams, T_HandlerReturn],
HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ] + ], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ]: + # TODO + pass + + async def publish( + self, + message: SendableMessage, + *args: Any, + reply_to: str = "", + rpc: bool = False, + rpc_timeout: Optional[float] = None, + raise_timeout: bool = False, + **kwargs: Any, + ) -> Optional[SendableMessage]: + # TODO + pass + + def publisher( + self, key: Any, publisher: BasePublisher[MsgType] + ) -> BasePublisher[MsgType]: + # TODO + pass + + async def _consume(self, queue_url: str, handler: Handler) -> NoReturn: + c = self._get_log_context(None, handler.queue.name) + + connected = True + with context.scope("queue_url", queue_url): + while True: + try: + if connected is False: + await self.create_queue(handler.queue) + + r = await self._connection.receive_message( + QueueUrl=queue_url, + **handler.consumer_params, + ) + + except Exception as e: + if connected is True: + self._log(e, logging.WARNING, c, exc_info=e) + self._queues.pop(handler.queue.name) + connected = False + + await anyio.sleep(5) + + else: + if connected is False: + self._log("Connection established", logging.INFO, c) + connected = True + + has_trash_messages = False + messages = r.get("Messages", []) + for msg in messages: + try: + await handler.callback(msg, True) + except Exception: + # any failed message triggers the back-off below + has_trash_messages = True + + if has_trash_messages is True: + await anyio.sleep( + handler.consumer_params.get("WaitTimeSeconds", 1.0) + ) diff --git a/faststream/sqs/handler.py b/faststream/sqs/handler.py new file mode 100644 index 0000000000..d12bf4ff33 --- /dev/null +++ b/faststream/sqs/handler.py @@ -0,0 +1,24 @@ +import asyncio +from typing import Any, Optional + +from typing_extensions import TypeAlias + +from faststream.broker.handler import AsyncHandler +from faststream.sqs.shared.schemas import SQSQueue +from faststream.types import AnyDict + +QueueUrl: TypeAlias = str + + +class
LogicSQSHandler(AsyncHandler[AnyDict]): + queue: SQSQueue + consumer_params: AnyDict + task: Optional["asyncio.Task[Any]"] = None + + async def start(self) -> None: + # TODO + pass + + async def close(self) -> None: + # TODO + pass diff --git a/faststream/sqs/message.py b/faststream/sqs/message.py new file mode 100644 index 0000000000..4c828e74c9 --- /dev/null +++ b/faststream/sqs/message.py @@ -0,0 +1,26 @@ +from typing import Any + +from faststream.broker.message import StreamMessage +from faststream.types import AnyDict + + +class SQSMessage(StreamMessage[AnyDict]): + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.commited = False + + async def ack(self, **kwargs: Any) -> None: + # TODO: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html + self.commited = True + + async def nack(self, **kwargs: Any) -> None: + # TODO + self.commited = True + + async def reject(self, **kwargs: Any) -> None: + # TODO + self.commited = True diff --git a/faststream/sqs/producer.py b/faststream/sqs/producer.py new file mode 100644 index 0000000000..1831d34390 --- /dev/null +++ b/faststream/sqs/producer.py @@ -0,0 +1,3 @@ +class SQSFastProducer: + # TODO + pass diff --git a/faststream/sqs/router.py b/faststream/sqs/router.py new file mode 100644 index 0000000000..10cc5498e8 --- /dev/null +++ b/faststream/sqs/router.py @@ -0,0 +1,6 @@ +from faststream.sqs.shared.router import SQSRouter as BaseRouter + + +class SQSRouter(BaseRouter): + # TODO + pass diff --git a/faststream/sqs/shared/__init__.py b/faststream/sqs/shared/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/faststream/sqs/shared/logging.py b/faststream/sqs/shared/logging.py new file mode 100644 index 0000000000..bf306daf2b --- /dev/null +++ b/faststream/sqs/shared/logging.py @@ -0,0 +1,56 @@ +import logging +from typing import Any, Optional + +from faststream._compat import override +from 
faststream.broker.core.mixins import LoggingMixin +from faststream.broker.message import StreamMessage +from faststream.log import access_logger +from faststream.types import AnyDict + + +class SQSLoggingMixin(LoggingMixin): + _max_queue_len: int + + def __init__( + self, + *args: Any, + logger: Optional[logging.Logger] = access_logger, + log_level: int = logging.INFO, + log_fmt: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + **kwargs, + ) + self._max_queue_len = 4 + + @override + def _get_log_context( # type: ignore[override] + self, + message: Optional[StreamMessage[Any]], + queue: str = "", + ) -> AnyDict: + return { + "queue": queue, + **super()._get_log_context(message), + } + + @property + def fmt(self) -> str: + return self._fmt or ( + "%(asctime)s %(levelname)s - " + f"%(queue)-{self._max_queue_len}s | " + "%(message_id)-10s " + "- %(message)s" + ) + + def _setup_log_context( + self, + queue: Optional[str] = None, + ) -> None: + if queue is not None: + self._max_queue_len = max((self._max_queue_len, len(queue))) diff --git a/faststream/sqs/shared/router.py b/faststream/sqs/shared/router.py new file mode 100644 index 0000000000..8805a4fd6b --- /dev/null +++ b/faststream/sqs/shared/router.py @@ -0,0 +1,40 @@ +from typing import Callable, Union + +from faststream._compat import model_copy, override +from faststream.broker.router import BrokerRoute as SQSRoute +from faststream.broker.router import BrokerRouter +from faststream.broker.types import P_HandlerParams, T_HandlerReturn +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.sqs.shared.schemas import SQSQueue +from faststream.types import Any, AnyDict, SendableMessage, Sequence + + +class SQSRouter(BrokerRouter[AnyDict]): + def __init__( + self, + prefix: str = "", + handlers: Sequence[SQSRoute[AnyDict, SendableMessage]] = (), + **kwargs: Any, + ): + for h in handlers: + if (q := 
h.kwargs.pop("queue", None)) is None: + q, h.args = h.args[0], h.args[1:] + queue = SQSQueue.validate(q) + new_q = model_copy(queue, update={"name": prefix + queue.name}) + h.args = (new_q, *h.args) + + super().__init__(prefix, handlers, **kwargs) + + @override + def subscriber( # type: ignore[override] + self, + queue: Union[str, SQSQueue], + *args: Any, + **kwargs: AnyDict, + ) -> Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[AnyDict, P_HandlerParams, T_HandlerReturn], + ]: + q = SQSQueue.validate(queue) + new_q = model_copy(q, update={"name": self.prefix + q.name}) + return self._wrap_subscriber(new_q, *args, **kwargs) diff --git a/faststream/sqs/shared/schemas.py b/faststream/sqs/shared/schemas.py new file mode 100644 index 0000000000..ce2b6821b3 --- /dev/null +++ b/faststream/sqs/shared/schemas.py @@ -0,0 +1,198 @@ +from typing import Any, Dict, Optional, Sequence + +from pydantic import BaseModel, Field, PositiveInt +from typing_extensions import Literal + +from faststream._compat import PYDANTIC_V2 +from faststream.broker.schemas import NameRequired + + +class RedrivePolicy(BaseModel): + """SQS Queue RedrivePolicy attribute details""" + + dead_letter_target: str = Field( + default="", + alias="deadLetterTargetArn", + ) + max_receive_count: PositiveInt = Field( + default=10, + alias="maxReceiveCount", + ) + + +class RedriveAllowPolicy(BaseModel): + """SQS Queue RedriveAllowPolicy attribute details""" + + redrive_permission: Literal["allowAll", "denyAll", "byQueue"] = Field( + default="allowAll", + alias="redrivePermission", + ) + source_queue_arns: Sequence[str] = Field( + default_factory=tuple, + alias="sourceQueueArns", + max_length=10, + ) + + +class SQSQueue(NameRequired): + """SQS Basic Queue initialization attributes""" + + fifo: bool = Field( + default=False, + alias="FifoQueue", + ) + + delay_seconds: int = Field( + default=0, + # alias="DelaySeconds", + ge=0, + le=900, + ) + max_message_size: int = Field( +
default=262_144, + alias="MaximumMessageSize", + ge=1024, + le=262_144, + ) + retention_period_sec: int = Field( + 345_600, + alias="MessageRetentionPeriod", + ge=60, + le=1_209_600, + ) + receive_wait_time_sec: int = Field( + default=0, + alias="ReceiveMessageWaitTimeSeconds", + ge=0, + le=20, + ) + visibility_timeout_sec: int = Field( + default=30, + alias="VisibilityTimeout", + ge=0, + le=43_200, + ) + redrive_policy: RedrivePolicy = Field( + default_factory=RedrivePolicy, + alias="RedrivePolicy", + ) + redrive_allow_policy: RedriveAllowPolicy = Field( + default_factory=RedriveAllowPolicy, + alias="RedriveAllowPolicy", + ) + + kms_master_key_id: str = Field(default="", alias="KmsMasterKeyId") + kms_data_key_reuse_period_sec: int = Field( + default=300, + alias="KmsDataKeyReusePeriodSeconds", + ge=60, + le=86_400, + ) + sse_enabled: bool = Field( + default=False, + alias="SqsManagedSseEnabled", + ) + tags: Dict[str, str] = Field( + default_factory=dict, + ) + + def __init__( + self, + name: str, + fifo: bool = False, + delay_seconds: int = 0, + max_message_size: int = 262_144, + visibility_timeout_sec: int = 0, + receive_wait_time_sec: int = 0, + retention_period_sec: int = 345_600, + redrive_policy: Optional[RedrivePolicy] = None, + redrive_allow_policy: Optional[RedriveAllowPolicy] = None, + kms_master_key_id: str = "", + kms_data_key_reuse_period_sec: int = 300, + sse_enabled: bool = False, + tags: Optional[Dict[str, str]] = None, + **kwargs: Any, + ): + super().__init__( + name=name, + fifo=fifo, + visibility_timeout_sec=visibility_timeout_sec, + receive_wait_time_sec=receive_wait_time_sec, + retention_period_sec=retention_period_sec, + max_message_size=max_message_size, + delay_seconds=delay_seconds, + redrive_policy=redrive_policy or RedrivePolicy(), + redrive_allow_policy=redrive_allow_policy or RedriveAllowPolicy(), + kms_master_key_id=kms_master_key_id, + kms_data_key_reuse_period_sec=kms_data_key_reuse_period_sec, + sse_enabled=sse_enabled, + tags=tags or
{}, + **kwargs, + ) + + if PYDANTIC_V2: + model_config = {"arbitrary_types_allowed": True} + else: + + class Config: + arbitrary_types_allowed = True + + +class FifoQueue(SQSQueue): + """SQS FIFO Queue initialization attributes""" + + fifo: bool = Field( + default=True, + alias="FifoQueue", + ) + content_based_deduplication: bool = Field( + default=True, + alias="ContentBasedDeduplication", + ) + deduplication_scope: Optional[Literal["messageGroup", "queue"]] = Field( + default=None, + alias="DeduplicationScope", + ) + + # TODO: pydantic validation and test + # allow perMessageGroup only for messageGroup deduplication_scope + throughput_limit: Optional[Literal["perMessageGroup", "perQueue"]] = Field( + default=None, + alias="FifoThroughputLimit", + ) + + def __init__( + self, + name: str, + fifo: bool = True, + delay_seconds: int = 0, + max_message_size: int = 262_144, + visibility_timeout_sec: int = 0, + receive_wait_time_sec: int = 0, + retention_period_sec: int = 345_600, + content_based_deduplication: bool = True, + deduplication_scope: Optional[Literal["messageGroup", "queue"]] = None, + throughput_limit: Optional[Literal["perMessageGroup", "perQueue"]] = None, + redrive_policy: Optional[RedrivePolicy] = None, + redrive_allow_policy: Optional[RedriveAllowPolicy] = None, + kms_data_key_reuse_period_sec: int = 300, + sse_enabled: bool = False, + tags: Optional[Dict[str, str]] = None, + ): + super().__init__( + name=name, + fifo=fifo, + visibility_timeout_sec=visibility_timeout_sec, + receive_wait_time_sec=receive_wait_time_sec, + content_based_deduplication=content_based_deduplication, + retention_period_sec=retention_period_sec, + max_message_size=max_message_size, + delay_seconds=delay_seconds, + redrive_policy=redrive_policy or RedrivePolicy(), + redrive_allow_policy=redrive_allow_policy or RedriveAllowPolicy(), + kms_data_key_reuse_period_sec=kms_data_key_reuse_period_sec, + sse_enabled=sse_enabled, + deduplication_scope=deduplication_scope, + 
throughput_limit=throughput_limit, + tags=tags or {}, + ) diff --git a/faststream/sqs/test.py b/faststream/sqs/test.py new file mode 100644 index 0000000000..17a5bddcde --- /dev/null +++ b/faststream/sqs/test.py @@ -0,0 +1,7 @@ +from faststream.broker.test import TestBroker +from faststream.sqs.broker import SQSBroker + + +class TestSQSBroker(TestBroker[SQSBroker]): + # TODO + pass