diff --git a/aws/logs_monitoring/forwarder.py b/aws/logs_monitoring/forwarder.py index be18b9d65..5cc1a02f3 100644 --- a/aws/logs_monitoring/forwarder.py +++ b/aws/logs_monitoring/forwarder.py @@ -14,8 +14,8 @@ from logs.datadog_matcher import DatadogMatcher from logs.datadog_scrubber import DatadogScrubber from logs.helpers import add_retry_tag +from retry import create_storage from retry.enums import RetryPrefix -from retry.storage import Storage from settings import ( DD_API_KEY, DD_FORWARD_LOG, @@ -41,7 +41,7 @@ def __init__(self, function_prefix): self.trace_connection = TraceConnection( DD_TRACE_INTAKE_URL, DD_API_KEY, DD_SKIP_SSL_VALIDATION ) - self.storage = Storage(function_prefix) + self.storage = create_storage(function_prefix) def forward(self, logs, metrics, traces): """ diff --git a/aws/logs_monitoring/retry/__init__.py b/aws/logs_monitoring/retry/__init__.py new file mode 100644 index 000000000..eac596517 --- /dev/null +++ b/aws/logs_monitoring/retry/__init__.py @@ -0,0 +1,20 @@ +from retry.base_storage import BaseStorage +from settings import DD_SQS_QUEUE_URL + + +def create_storage(function_prefix) -> BaseStorage: + """Select the appropriate storage backend based on configuration. + + If DD_SQS_QUEUE_URL is set, use SQS. Otherwise, fall back to S3. + The S3 backend may be initialized with an empty bucket name when the + retry feature is disabled (DD_STORE_FAILED_EVENTS=false) — this is + safe because storage methods are only called when retry is enabled. + """ + if DD_SQS_QUEUE_URL: + from retry.sqs_storage import SQSStorage + + return SQSStorage(function_prefix) + + from retry.storage import S3Storage + + return S3Storage(function_prefix) diff --git a/aws/logs_monitoring/retry/base_storage.py b/aws/logs_monitoring/retry/base_storage.py new file mode 100644 index 000000000..9becb0ddb --- /dev/null +++ b/aws/logs_monitoring/retry/base_storage.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + + +class BaseStorage(ABC): + @abstractmethod + def get_data(self, prefix) -> dict: + """Retrieve stored data for a given prefix. Returns {key: data}.""" + ... + + @abstractmethod + def store_data(self, prefix, data) -> None: + """Store data under the given prefix.""" + ... + + @abstractmethod + def delete_data(self, key) -> None: + """Delete stored data by key.""" + ... diff --git a/aws/logs_monitoring/retry/sqs_storage.py b/aws/logs_monitoring/retry/sqs_storage.py new file mode 100644 index 000000000..bd594ce33 --- /dev/null +++ b/aws/logs_monitoring/retry/sqs_storage.py @@ -0,0 +1,170 @@ +import json +import logging +import os + +import boto3 +from botocore.exceptions import ClientError + +from retry.base_storage import BaseStorage +from settings import DD_SQS_QUEUE_URL + +logger = logging.getLogger(__name__) +logger.setLevel(logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper())) + +# SQS max message size is 256KB; use 240KB to leave room for attributes/overhead +SQS_MAX_CHUNK_BYTES = 240 * 1024 +SQS_MAX_MESSAGES_PER_RECEIVE = 10 +SQS_MAX_POLL_ITERATIONS = 10 + + +class SQSStorage(BaseStorage): + def __init__(self, function_prefix): + self.queue_url = DD_SQS_QUEUE_URL + self.sqs_client = boto3.client("sqs") + self.function_prefix = function_prefix + + def get_data(self, prefix): + """Poll SQS for messages matching prefix and function_prefix. + + Returns {receipt_handle: data} for matching messages. + Non-matching messages are released immediately by resetting their + visibility timeout to 0. 
+ """ + key_data = {} + + for _ in range(SQS_MAX_POLL_ITERATIONS): + try: + response = self.sqs_client.receive_message( + QueueUrl=self.queue_url, + MaxNumberOfMessages=SQS_MAX_MESSAGES_PER_RECEIVE, + MessageAttributeNames=["retry_prefix", "function_prefix"], + WaitTimeSeconds=0, + ) + except ClientError as e: + logger.error(f"Failed to receive SQS messages: {e}") + break + + messages = response.get("Messages", []) + if not messages: + break + + for message in messages: + receipt_handle = message["ReceiptHandle"] + msg_retry_prefix = self._get_message_attr(message, "retry_prefix") + msg_function_prefix = self._get_message_attr(message, "function_prefix") + + if ( + msg_retry_prefix != str(prefix) + or msg_function_prefix != self.function_prefix + ): + self._release_message(receipt_handle) + continue + + data = self._deserialize(message["Body"]) + if data is not None: + key_data[receipt_handle] = data + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Found {len(key_data)} SQS retry messages for prefix {prefix}" + ) + + return key_data + + def store_data(self, prefix, data): + """Store data as one or more SQS messages, chunking to stay under the size limit.""" + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Storing retry data to SQS for prefix {prefix}") + + chunks = self._chunk_data(data) + for chunk in chunks: + serialized = self._serialize(chunk) + try: + self.sqs_client.send_message( + QueueUrl=self.queue_url, + MessageBody=serialized, + MessageAttributes={ + "retry_prefix": { + "DataType": "String", + "StringValue": str(prefix), + }, + "function_prefix": { + "DataType": "String", + "StringValue": self.function_prefix, + }, + }, + ) + except ClientError as e: + logger.error(f"Failed to send SQS message for prefix {prefix}: {e}") + + def delete_data(self, key): + """Delete a message by receipt handle. Idempotent — logs and swallows errors.""" + try: + self.sqs_client.delete_message( + QueueUrl=self.queue_url, + ReceiptHandle=key, + ) + except ClientError as e: + logger.error(f"Failed to delete SQS message (receipt={key}): {e}") + + def _release_message(self, receipt_handle): + """Make a non-matching message immediately visible to other consumers.""" + try: + self.sqs_client.change_message_visibility( + QueueUrl=self.queue_url, + ReceiptHandle=receipt_handle, + VisibilityTimeout=0, + ) + except ClientError as e: + logger.error(f"Failed to release SQS message: {e}") + + @staticmethod + def _get_message_attr(message, attr_name): + """Extract a string attribute value from an SQS message.""" + attrs = message.get("MessageAttributes", {}) + return attrs.get(attr_name, {}).get("StringValue") + + def _chunk_data(self, data): + """Split a list of items into chunks that each fit under SQS_MAX_CHUNK_BYTES.""" + if not isinstance(data, list): + return [data] + + chunks = [] + current_chunk = [] + current_size = 2 # account for JSON array brackets "[]" + + for item in data: + item_json = json.dumps(item, ensure_ascii=False) + item_size = len(item_json.encode("UTF-8")) + # +1 for the comma separator between items + separator_size = 1 if current_chunk else 0 + + if current_size + separator_size + item_size > SQS_MAX_CHUNK_BYTES: + if current_chunk: + chunks.append(current_chunk) + if 2 + item_size > SQS_MAX_CHUNK_BYTES: + logger.warning( + f"Single item exceeds SQS message size limit " + f"({item_size} bytes > {SQS_MAX_CHUNK_BYTES} bytes). " + f"SQS send will fail for this chunk." 
+ ) + current_chunk = [item] + current_size = 2 + item_size + else: + current_chunk.append(item) + current_size += separator_size + item_size + + if current_chunk: + chunks.append(current_chunk) + + return chunks or [data] + + def _serialize(self, data): + return json.dumps(data, ensure_ascii=False) + + def _deserialize(self, data): + try: + return json.loads(data) + except (json.JSONDecodeError, TypeError) as e: + logger.error(f"Failed to deserialize SQS message body: {e}") + return None diff --git a/aws/logs_monitoring/retry/storage.py b/aws/logs_monitoring/retry/storage.py index 527ed5c7f..a75ad5d69 100644 --- a/aws/logs_monitoring/retry/storage.py +++ b/aws/logs_monitoring/retry/storage.py @@ -6,13 +6,14 @@ import boto3 from botocore.exceptions import ClientError +from retry.base_storage import BaseStorage from settings import DD_S3_BUCKET_NAME, DD_S3_RETRY_DIRNAME logger = logging.getLogger(__name__) logger.setLevel(logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper())) -class Storage(object): +class S3Storage(BaseStorage): def __init__(self, function_prefix): self.bucket_name = DD_S3_BUCKET_NAME self.s3_client = boto3.client("s3") @@ -81,7 +82,7 @@ def _get_key_prefix(self, retry_prefix): return f"{DD_S3_RETRY_DIRNAME}/{self.function_prefix}/{str(retry_prefix)}/" def _serialize(self, data): - return bytes(json.dumps(data).encode("UTF-8")) + return json.dumps(data).encode("UTF-8") def _deserialize(self, data): return json.loads(data.decode("UTF-8")) diff --git a/aws/logs_monitoring/settings.py b/aws/logs_monitoring/settings.py index e125416c1..a4d2fc7eb 100644 --- a/aws/logs_monitoring/settings.py +++ b/aws/logs_monitoring/settings.py @@ -242,10 +242,10 @@ def is_api_key_valid(): # Check if the API key is the correct number of characters if len(DD_API_KEY) != 32: - raise Exception(f""" - Invalid Datadog API key format. Expected 32 characters, received {len(DD_API_KEY)}. - Verify your API key at https://app.{DD_SITE}/organization-settings/api-keys - """) + raise Exception( + f"Invalid Datadog API key format. Expected 32 characters, received {len(DD_API_KEY)}. " + f"Verify your API key at https://app.{DD_SITE}/organization-settings/api-keys" + ) # Validate the API key logger.debug("Validating the Datadog API key") @@ -379,3 +379,4 @@ def get_enrich_cloudwatch_tags(): DD_S3_RETRY_DIRNAME = "failed_events" DD_RETRY_KEYWORD = "retry" DD_STORE_FAILED_EVENTS = get_env_var("DD_STORE_FAILED_EVENTS", "false", boolean=True) +DD_SQS_QUEUE_URL = get_env_var("DD_SQS_QUEUE_URL", default=None) diff --git a/aws/logs_monitoring/steps/parsing.py b/aws/logs_monitoring/steps/parsing.py index ab512f88a..738fa8971 100644 --- a/aws/logs_monitoring/steps/parsing.py +++ b/aws/logs_monitoring/steps/parsing.py @@ -247,12 +247,6 @@ def normalize_events(events, metadata): def collect_and_count(events): - collected = [] - counter = 0 - for event in events: - counter += 1 - collected.append(event) - - send_event_metric("incoming_events", counter) - + collected = list(events) + send_event_metric("incoming_events", len(collected)) return collected diff --git a/aws/logs_monitoring/template.yaml b/aws/logs_monitoring/template.yaml index a4215365d..a3a3185c1 100644 --- a/aws/logs_monitoring/template.yaml +++ b/aws/logs_monitoring/template.yaml @@ -276,6 +276,10 @@ Parameters: Type: Number Default: 6 Description: Interval in hours for scheduled forwarder invocation (via AWS EventBridge). 
+  DdSqsQueueUrl:
+    Type: String
+    Default: ""
+    Description: URL of an existing SQS queue for failed event storage (alternative to S3). When set, the forwarder uses SQS instead of S3 for retry storage.
   DdForwarderExistingBucketName:
     Type: String
     Default: ""
@@ -404,6 +408,11 @@ Conditions:
       - !Condition CreateS3Bucket
       - !Not
        - !Equals [!Ref DdForwarderExistingBucketName, ""]
+  SetDdSqsQueueUrl: !Not
+    - !Equals [!Ref DdSqsQueueUrl, ""]
+  HasStorageBackend: !Or
+    - !Condition SetForwarderBucket
+    - !Condition SetDdSqsQueueUrl
   SetVpcSecurityGroupIds: !Not
     - !Equals [!Join ["", !Ref VPCSecurityGroupIds], ""]
   SetVpcSubnetIds: !Not
@@ -531,9 +540,13 @@ Resources:
             - !Ref DdPort
             - !Ref AWS::NoValue
           DD_STORE_FAILED_EVENTS: !If
-            - SetForwarderBucket
+            - HasStorageBackend
             - !Ref DdStoreFailedEvents
             - !Ref AWS::NoValue
+          DD_SQS_QUEUE_URL: !If
+            - SetDdSqsQueueUrl
+            - !Ref DdSqsQueueUrl
+            - !Ref AWS::NoValue
           REDACT_IP: !If
             - SetRedactIp
             - !Ref RedactIp
@@ -770,6 +783,20 @@ Resources:
                   - !Ref SqsQueueArnList
                   - "*"
                 Effect: Allow
+              - !If
+                - SetDdSqsQueueUrl # Access SQS queue for failed event storage
+                - Action:
+                    - sqs:SendMessage
+                    - sqs:ReceiveMessage
+                    - sqs:DeleteMessage
+                    - sqs:ChangeMessageVisibility
+                  Resource: !Sub
+                    - "arn:${AWS::Partition}:sqs:${Region}:${Account}:${QueueName}"
+                    - Region: !Select [1, !Split [".", !Select [2, !Split ["/", !Ref DdSqsQueueUrl]]]]
+                      Account: !Select [3, !Split ["/", !Ref DdSqsQueueUrl]]
+                      QueueName: !Select [4, !Split ["/", !Ref DdSqsQueueUrl]]
+                  Effect: Allow
+                - !Ref AWS::NoValue
       Tags:
         - Value: !FindInMap [Constants, DdForwarder, Version]
           Key: dd_forwarder_version
@@ -1159,6 +1186,7 @@ Metadata:
           - DdForwarderExistingBucketName
           - DdForwarderBucketName
           - DdStoreFailedEvents
+          - DdSqsQueueUrl
           - DdLogLevel
       ParameterLabels:
         DdApiKey:
diff --git a/aws/logs_monitoring/tests/test_s3_storage.py b/aws/logs_monitoring/tests/test_s3_storage.py
new file mode 100644
index 000000000..8f6e0820d
--- /dev/null
+++ b/aws/logs_monitoring/tests/test_s3_storage.py
@@ -0,0 +1,90 @@
+import json
+import unittest
+from unittest.mock import MagicMock, patch
+
+from botocore.exceptions import ClientError
+
+from retry.storage import S3Storage
+
+
+class TestS3Storage(unittest.TestCase):
+    def setUp(self):
+        self.mock_s3 = MagicMock()
+        with patch("retry.storage.boto3") as mock_boto3:
+            mock_boto3.client.return_value = self.mock_s3
+            with patch("retry.storage.DD_S3_BUCKET_NAME", "test-bucket"):
+                self.storage = S3Storage("test_function_prefix")
+
+    def test_store_data_puts_object(self):
+        self.storage.store_data("logs", [{"message": "hello"}])
+        self.mock_s3.put_object.assert_called_once()
+        call_kwargs = self.mock_s3.put_object.call_args[1]
+        self.assertEqual(call_kwargs["Bucket"], "test-bucket")
+        self.assertIn("failed_events/test_function_prefix/logs/", call_kwargs["Key"])
+        self.assertEqual(
+            json.loads(call_kwargs["Body"].decode("UTF-8")), [{"message": "hello"}]
+        )
+
+    def test_store_data_handles_client_error(self):
+        self.mock_s3.put_object.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "PutObject"
+        )
+        # Should not raise
+        self.storage.store_data("logs", [{"message": "hello"}])
+
+    def test_get_data_returns_data_for_keys(self):
+        self.mock_s3.list_objects_v2.return_value = {
+            "Contents": [{"Key": "failed_events/test_function_prefix/logs/123"}]
+        }
+        body_mock = MagicMock()
+        body_mock.read.return_value = json.dumps([{"message": "hello"}]).encode("UTF-8")
+        self.mock_s3.get_object.return_value = {"Body": body_mock}
+
+        result = self.storage.get_data("logs")
+        self.assertEqual(
+            result,
+            {"failed_events/test_function_prefix/logs/123": [{"message": "hello"}]},
+        )
+
+    def test_get_data_handles_empty_bucket(self):
+        self.mock_s3.list_objects_v2.return_value = {}
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_get_data_handles_list_error(self):
+        self.mock_s3.list_objects_v2.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "ListObjectsV2"
+        )
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_get_data_handles_fetch_error(self):
+        self.mock_s3.list_objects_v2.return_value = {
+            "Contents": [{"Key": "failed_events/test_function_prefix/logs/123"}]
+        }
+        self.mock_s3.get_object.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "GetObject"
+        )
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {"failed_events/test_function_prefix/logs/123": None})
+
+    def test_delete_data_deletes_object(self):
+        self.storage.delete_data("failed_events/test_function_prefix/logs/123")
+        self.mock_s3.delete_object.assert_called_once_with(
+            Bucket="test-bucket", Key="failed_events/test_function_prefix/logs/123"
+        )
+
+    def test_delete_data_handles_client_error(self):
+        self.mock_s3.delete_object.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "DeleteObject"
+        )
+        # Should not raise
+        self.storage.delete_data("some_key")
+
+    def test_get_key_prefix(self):
+        prefix = self.storage._get_key_prefix("logs")
+        self.assertEqual(prefix, "failed_events/test_function_prefix/logs/")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/aws/logs_monitoring/tests/test_sqs_storage.py b/aws/logs_monitoring/tests/test_sqs_storage.py
new file mode 100644
index 000000000..e7d22546f
--- /dev/null
+++ b/aws/logs_monitoring/tests/test_sqs_storage.py
@@ -0,0 +1,194 @@
+import json
+import unittest
+from unittest.mock import MagicMock, patch
+
+from botocore.exceptions import ClientError
+
+from retry.sqs_storage import SQSStorage, SQS_MAX_CHUNK_BYTES
+
+
+class TestSQSStorage(unittest.TestCase):
+    def setUp(self):
+        self.mock_sqs = MagicMock()
+        with patch("retry.sqs_storage.boto3") as mock_boto3:
+            mock_boto3.client.return_value = self.mock_sqs
+            with patch(
+                "retry.sqs_storage.DD_SQS_QUEUE_URL",
+                "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",
+            ):
+                self.storage = SQSStorage("test_function_prefix")
+
+    def test_store_data_sends_message_with_attributes(self):
+        data = [{"message": "hello"}]
+        self.storage.store_data("logs", data)
+
+        self.mock_sqs.send_message.assert_called_once()
+        call_kwargs = self.mock_sqs.send_message.call_args[1]
+        self.assertEqual(
+            call_kwargs["QueueUrl"],
+            "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",
+        )
+        self.assertEqual(
+            call_kwargs["MessageAttributes"]["retry_prefix"]["StringValue"], "logs"
+        )
+        self.assertEqual(
+            call_kwargs["MessageAttributes"]["function_prefix"]["StringValue"],
+            "test_function_prefix",
+        )
+        self.assertEqual(json.loads(call_kwargs["MessageBody"]), data)
+
+    def test_store_data_chunks_large_data(self):
+        # Create two items that each fit individually but together exceed 240KB
+        large_item = {"message": "x" * (SQS_MAX_CHUNK_BYTES - 50)}
+        small_item = {"message": "y" * 100}
+        data = [large_item, small_item]
+
+        self.storage.store_data("logs", data)
+
+        # Should send 2 messages (items can't fit in one chunk)
+        self.assertEqual(self.mock_sqs.send_message.call_count, 2)
+
+    def test_store_data_handles_client_error(self):
+        self.mock_sqs.send_message.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "SendMessage"
+        )
+        # Should not raise
+        self.storage.store_data("logs", [{"message": "hello"}])
+
+    def test_get_data_returns_matching_messages(self):
+        self.mock_sqs.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": "handle1",
+                        "Body": json.dumps([{"message": "hello"}]),
+                        "MessageAttributes": {
+                            "retry_prefix": {"StringValue": "logs"},
+                            "function_prefix": {"StringValue": "test_function_prefix"},
+                        },
+                    }
+                ]
+            },
+            {"Messages": []},
+        ]
+
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {"handle1": [{"message": "hello"}]})
+
+    def test_get_data_releases_non_matching_messages(self):
+        self.mock_sqs.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": "handle_other",
+                        "Body": json.dumps([{"message": "other"}]),
+                        "MessageAttributes": {
+                            "retry_prefix": {"StringValue": "metrics"},
+                            "function_prefix": {"StringValue": "other_function"},
+                        },
+                    }
+                ]
+            },
+            {"Messages": []},
+        ]
+
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+        self.mock_sqs.change_message_visibility.assert_called_once_with(
+            QueueUrl="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",
+            ReceiptHandle="handle_other",
+            VisibilityTimeout=0,
+        )
+
+    def test_get_data_handles_empty_queue(self):
+        self.mock_sqs.receive_message.return_value = {"Messages": []}
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_get_data_handles_no_messages_key(self):
+        self.mock_sqs.receive_message.return_value = {}
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_get_data_handles_client_error(self):
+        self.mock_sqs.receive_message.side_effect = ClientError(
+            {"Error": {"Code": "500", "Message": "Error"}}, "ReceiveMessage"
+        )
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_get_data_skips_invalid_json(self):
+        self.mock_sqs.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": "handle1",
+                        "Body": "not valid json{{{",
+                        "MessageAttributes": {
+                            "retry_prefix": {"StringValue": "logs"},
+                            "function_prefix": {"StringValue": "test_function_prefix"},
+                        },
+                    }
+                ]
+            },
+            {"Messages": []},
+        ]
+
+        result = self.storage.get_data("logs")
+        self.assertEqual(result, {})
+
+    def test_delete_data_calls_delete_message(self):
+        self.storage.delete_data("receipt_handle_123")
+        self.mock_sqs.delete_message.assert_called_once_with(
+            QueueUrl="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",
+            ReceiptHandle="receipt_handle_123",
+        )
+
+    def test_delete_data_is_idempotent(self):
+        self.mock_sqs.delete_message.side_effect = ClientError(
+            {"Error": {"Code": "ReceiptHandleIsInvalid", "Message": "Error"}},
+            "DeleteMessage",
+        )
+        # Should not raise
+        self.storage.delete_data("already_deleted_handle")
+
+    def test_chunk_data_single_small_list(self):
+        data = [{"a": 1}, {"b": 2}]
+        chunks = self.storage._chunk_data(data)
+        self.assertEqual(len(chunks), 1)
+        self.assertEqual(chunks[0], data)
+
+    def test_chunk_data_non_list(self):
+        data = {"key": "value"}
+        chunks = self.storage._chunk_data(data)
+        self.assertEqual(chunks, [data])
+
+    def test_chunk_data_empty_list(self):
+        chunks = self.storage._chunk_data([])
+        self.assertEqual(chunks, [[]])
+
+    def test_get_data_polls_multiple_iterations(self):
+        """Verify that get_data keeps polling until an empty response."""
+        self.mock_sqs.receive_message.side_effect = [
+            {
+                "Messages": [
+                    {
+                        "ReceiptHandle": f"handle_{i}",
+                        "Body": json.dumps([{"msg": i}]),
+                        "MessageAttributes": {
+                            "retry_prefix": {"StringValue": "logs"},
+                            "function_prefix": {"StringValue": "test_function_prefix"},
+                        },
+                    }
+                ]
+            }
+            for i in range(3)
+        ] + [{"Messages": []}]
+
+        result = self.storage.get_data("logs")
+        self.assertEqual(len(result), 3)
+        self.assertEqual(self.mock_sqs.receive_message.call_count, 4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/aws/logs_monitoring/tests/test_storage_factory.py b/aws/logs_monitoring/tests/test_storage_factory.py
new file mode 100644
index 000000000..e85389267
--- /dev/null
+++ b/aws/logs_monitoring/tests/test_storage_factory.py
@@ -0,0 +1,58 @@
+import unittest
+from unittest.mock import patch
+
+from retry.storage import S3Storage
+from retry.sqs_storage import SQSStorage
+
+
+class TestCreateStorage(unittest.TestCase):
+    @patch("retry.sqs_storage.boto3")
+    @patch("retry.DD_SQS_QUEUE_URL", "https://sqs.us-east-1.amazonaws.com/123/queue")
+    @patch(
+        "retry.sqs_storage.DD_SQS_QUEUE_URL",
+        "https://sqs.us-east-1.amazonaws.com/123/queue",
+    )
+    def test_sqs_backend_when_queue_url_set(self, mock_boto3):
+        from retry import create_storage
+
+        storage = create_storage("func_prefix")
+        self.assertIsInstance(storage, SQSStorage)
+
+    @patch("retry.storage.boto3")
+    @patch("retry.DD_SQS_QUEUE_URL", None)
+    @patch("retry.storage.DD_S3_BUCKET_NAME", "my-bucket")
+    def test_s3_backend_when_no_queue_url(self, mock_boto3):
+        from retry import create_storage
+
+        storage = create_storage("func_prefix")
+        self.assertIsInstance(storage, S3Storage)
+
+    @patch("retry.storage.boto3")
+    @patch("retry.DD_SQS_QUEUE_URL", None)
+    def test_falls_back_to_s3_when_no_backend_configured(self, mock_boto3):
+        """When no SQS queue is configured, always fall back to S3Storage.
+
+        This preserves backward compatibility: S3Storage with an empty bucket
+        name is safe as long as DD_STORE_FAILED_EVENTS is false (the default).
+        """
+        from retry import create_storage
+
+        storage = create_storage("func_prefix")
+        self.assertIsInstance(storage, S3Storage)
+
+    @patch("retry.sqs_storage.boto3")
+    @patch("retry.DD_SQS_QUEUE_URL", "https://sqs.us-east-1.amazonaws.com/123/queue")
+    @patch(
+        "retry.sqs_storage.DD_SQS_QUEUE_URL",
+        "https://sqs.us-east-1.amazonaws.com/123/queue",
+    )
+    def test_sqs_takes_priority_over_s3(self, mock_boto3):
+        """SQS is selected whenever the queue URL is set, regardless of the S3 bucket configuration."""
+        from retry import create_storage
+
+        storage = create_storage("func_prefix")
+        self.assertIsInstance(storage, SQSStorage)
+
+
+if __name__ == "__main__":
+    unittest.main()
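
Reviewer note: the snippet below is a minimal, illustrative sketch of how the new factory is exercised at runtime; it is not part of the change set. The queue URL, the function prefix, and the bare string prefix "logs" are made-up values (real callers pass a RetryPrefix enum member), and DD_SQS_QUEUE_URL must be present in the environment before settings is first imported, since the value is read at import time.

    # Illustrative sketch only; all concrete values here are hypothetical.
    import os

    # Must be set before `settings` is imported anywhere in the process,
    # because DD_SQS_QUEUE_URL is resolved at module import time.
    os.environ["DD_SQS_QUEUE_URL"] = (
        "https://sqs.us-east-1.amazonaws.com/123456789012/retry-queue"
    )

    from retry import create_storage

    storage = create_storage("my-forwarder")  # SQSStorage, since the queue URL is set
    storage.store_data("logs", [{"message": "hello"}])

    # get_data returns {receipt_handle: data}; deleting by receipt handle after
    # a successful re-send prevents the message from being retried again.
    for receipt_handle, data in storage.get_data("logs").items():
        storage.delete_data(receipt_handle)

With the environment variable unset, the same call returns the S3 backend, so existing deployments are unaffected.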