Skip to content

Commit b47586e

Browse files
authored
Add async support for Amazon SNS Notifier (apache#56133)
* Add async support for Amazon SNS Notifier
1 parent afa7345 commit b47586e

File tree

4 files changed

+227
-96
lines changed

4 files changed

+227
-96
lines changed

providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import json
2323

2424
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
25+
from airflow.utils.helpers import prune_dict
2526

2627

2728
def _get_message_attribute(o):
@@ -38,6 +39,33 @@ def _get_message_attribute(o):
3839
)
3940

4041

42+
def _build_publish_kwargs(
43+
target_arn: str,
44+
message: str,
45+
subject: str | None = None,
46+
message_attributes: dict | None = None,
47+
message_deduplication_id: str | None = None,
48+
message_group_id: str | None = None,
49+
) -> dict[str, str | dict]:
50+
publish_kwargs: dict[str, str | dict] = prune_dict(
51+
{
52+
"TargetArn": target_arn,
53+
"MessageStructure": "json",
54+
"Message": json.dumps({"default": message}),
55+
"Subject": subject,
56+
"MessageDeduplicationId": message_deduplication_id,
57+
"MessageGroupId": message_group_id,
58+
}
59+
)
60+
61+
if message_attributes:
62+
publish_kwargs["MessageAttributes"] = {
63+
key: _get_message_attribute(val) for key, val in message_attributes.items()
64+
}
65+
66+
return publish_kwargs
67+
68+
4169
class SnsHook(AwsBaseHook):
4270
"""
4371
Interact with Amazon Simple Notification Service.
@@ -84,22 +112,50 @@ def publish_to_target(
84112
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
85113
This parameter applies only to FIFO (first-in-first-out) topics.
86114
"""
87-
publish_kwargs: dict[str, str | dict] = {
88-
"TargetArn": target_arn,
89-
"MessageStructure": "json",
90-
"Message": json.dumps({"default": message}),
91-
}
115+
return self.get_conn().publish(
116+
**_build_publish_kwargs(
117+
target_arn, message, subject, message_attributes, message_deduplication_id, message_group_id
118+
)
119+
)
92120

93-
# Construct args this way because boto3 distinguishes from missing args and those set to None
94-
if subject:
95-
publish_kwargs["Subject"] = subject
96-
if message_deduplication_id:
97-
publish_kwargs["MessageDeduplicationId"] = message_deduplication_id
98-
if message_group_id:
99-
publish_kwargs["MessageGroupId"] = message_group_id
100-
if message_attributes:
101-
publish_kwargs["MessageAttributes"] = {
102-
key: _get_message_attribute(val) for key, val in message_attributes.items()
103-
}
104-
105-
return self.get_conn().publish(**publish_kwargs)
121+
async def apublish_to_target(
122+
self,
123+
target_arn: str,
124+
message: str,
125+
subject: str | None = None,
126+
message_attributes: dict | None = None,
127+
message_deduplication_id: str | None = None,
128+
message_group_id: str | None = None,
129+
):
130+
"""
131+
Publish a message to a SNS topic or an endpoint.
132+
133+
.. seealso::
134+
- :external+boto3:py:meth:`SNS.Client.publish`
135+
136+
:param target_arn: either a TopicArn or an EndpointArn
137+
:param message: the default message you want to send
138+
:param subject: subject of message
139+
:param message_attributes: additional attributes to publish for message filtering. This should be
140+
a flat dict; the DataType to be sent depends on the type of the value:
141+
142+
- bytes = Binary
143+
- str = String
144+
- int, float = Number
145+
- iterable = String.Array
146+
:param message_deduplication_id: Every message must have a unique message_deduplication_id.
147+
This parameter applies only to FIFO (first-in-first-out) topics.
148+
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
149+
This parameter applies only to FIFO (first-in-first-out) topics.
150+
"""
151+
async with await self.get_async_conn() as async_client:
152+
return await async_client.publish(
153+
**_build_publish_kwargs(
154+
target_arn,
155+
message,
156+
subject,
157+
message_attributes,
158+
message_deduplication_id,
159+
message_group_id,
160+
)
161+
)

providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from functools import cached_property
2222

2323
from airflow.providers.amazon.aws.hooks.sns import SnsHook
24+
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
2425
from airflow.providers.common.compat.notifier import BaseNotifier
2526

2627

@@ -60,8 +61,13 @@ def __init__(
6061
subject: str | None = None,
6162
message_attributes: dict | None = None,
6263
region_name: str | None = None,
64+
**kwargs,
6365
):
64-
super().__init__()
66+
if AIRFLOW_V_3_1_PLUS:
67+
# Support for passing context was added in 3.1.0
68+
super().__init__(**kwargs)
69+
else:
70+
super().__init__()
6571
self.aws_conn_id = aws_conn_id
6672
self.region_name = region_name
6773
self.target_arn = target_arn
@@ -83,5 +89,14 @@ def notify(self, context):
8389
message_attributes=self.message_attributes,
8490
)
8591

92+
async def async_notify(self, context):
93+
"""Publish the notification message to Amazon SNS (async)."""
94+
await self.hook.apublish_to_target(
95+
target_arn=self.target_arn,
96+
message=self.message,
97+
subject=self.subject,
98+
message_attributes=self.message_attributes,
99+
)
100+
86101

87102
send_sns_notification = SnsNotifier

providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py

Lines changed: 102 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -17,95 +17,75 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20+
from unittest import mock
21+
2022
import pytest
2123
from moto import mock_aws
2224

2325
from airflow.providers.amazon.aws.hooks.sns import SnsHook
2426

27+
DEDUPE_ID = "test-dedupe-id"
28+
GROUP_ID = "test-group-id"
2529
MESSAGE = "Hello world"
26-
TOPIC_NAME = "test-topic"
2730
SUBJECT = "test-subject"
31+
INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable"
2832

33+
TOPIC_NAME = "test-topic"
34+
TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{TOPIC_NAME}"
2935

30-
@mock_aws
31-
class TestSnsHook:
32-
def test_get_conn_returns_a_boto3_connection(self):
33-
hook = SnsHook(aws_conn_id="aws_default")
34-
assert hook.get_conn() is not None
35-
36-
def test_publish_to_target_with_subject(self):
37-
hook = SnsHook(aws_conn_id="aws_default")
38-
39-
message = MESSAGE
40-
topic_name = TOPIC_NAME
41-
subject = SUBJECT
42-
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
43-
44-
response = hook.publish_to_target(target, message, subject)
36+
INVALID_ATTRIBUTES = {"test-non-iterable": object()}
37+
VALID_ATTRIBUTES = {
38+
"test-string": "string-value",
39+
"test-number": 123456,
40+
"test-array": ["first", "second", "third"],
41+
"test-binary": b"binary-value",
42+
}
4543

46-
assert "MessageId" in response
44+
MESSAGE_ID_KEY = "MessageId"
45+
TOPIC_ARN_KEY = "TopicArn"
4746

48-
def test_publish_to_target_with_attributes(self):
49-
hook = SnsHook(aws_conn_id="aws_default")
5047

51-
message = MESSAGE
52-
topic_name = TOPIC_NAME
53-
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
48+
class TestSnsHook:
49+
@pytest.fixture(autouse=True)
50+
def setup_moto(self):
51+
with mock_aws():
52+
yield
5453

55-
response = hook.publish_to_target(
56-
target,
57-
message,
58-
message_attributes={
59-
"test-string": "string-value",
60-
"test-number": 123456,
61-
"test-array": ["first", "second", "third"],
62-
"test-binary": b"binary-value",
63-
},
64-
)
54+
@pytest.fixture
55+
def hook(self):
56+
return SnsHook(aws_conn_id="aws_default")
6557

66-
assert "MessageId" in response
58+
@pytest.fixture
59+
def target(self, hook):
60+
return hook.get_conn().create_topic(Name=TOPIC_NAME).get(TOPIC_ARN_KEY)
6761

68-
def test_publish_to_target_plain(self):
69-
hook = SnsHook(aws_conn_id="aws_default")
62+
def test_get_conn_returns_a_boto3_connection(self, hook):
63+
assert hook.get_conn() is not None
7064

71-
message = MESSAGE
72-
topic_name = "test-topic"
73-
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
65+
def test_publish_to_target_with_subject(self, hook, target):
66+
response = hook.publish_to_target(target, MESSAGE, SUBJECT)
7467

75-
response = hook.publish_to_target(target, message)
68+
assert MESSAGE_ID_KEY in response
7669

77-
assert "MessageId" in response
70+
def test_publish_to_target_with_attributes(self, hook, target):
71+
response = hook.publish_to_target(target, MESSAGE, message_attributes=VALID_ATTRIBUTES)
7872

79-
def test_publish_to_target_error(self):
80-
hook = SnsHook(aws_conn_id="aws_default")
73+
assert MESSAGE_ID_KEY in response
8174

82-
message = "Hello world"
83-
topic_name = "test-topic"
84-
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
75+
def test_publish_to_target_plain(self, hook, target):
76+
response = hook.publish_to_target(target, MESSAGE)
8577

86-
error_message = (
87-
r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*"
88-
)
89-
with pytest.raises(TypeError, match=error_message):
90-
hook.publish_to_target(
91-
target,
92-
message,
93-
message_attributes={
94-
"test-non-iterable": object(),
95-
},
96-
)
78+
assert MESSAGE_ID_KEY in response
9779

98-
def test_publish_to_target_with_deduplication(self):
99-
hook = SnsHook(aws_conn_id="aws_default")
80+
def test_publish_to_target_error(self, hook, target):
81+
with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
82+
hook.publish_to_target(target, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
10083

101-
message = MESSAGE
102-
topic_name = TOPIC_NAME + ".fifo"
103-
deduplication_id = "abc"
104-
group_id = "a"
105-
target = (
84+
def test_publish_to_target_with_deduplication(self, hook):
85+
fifo_target = (
10686
hook.get_conn()
10787
.create_topic(
108-
Name=topic_name,
88+
Name=f"{TOPIC_NAME}.fifo",
10989
Attributes={
11090
"FifoTopic": "true",
11191
"ContentBasedDeduplication": "false",
@@ -115,7 +95,63 @@ def test_publish_to_target_with_deduplication(self):
11595
)
11696

11797
response = hook.publish_to_target(
118-
target, message, message_deduplication_id=deduplication_id, message_group_id=group_id
98+
fifo_target, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
99+
)
100+
assert MESSAGE_ID_KEY in response
101+
102+
103+
@pytest.mark.asyncio
104+
class TestAsyncSnsHook:
105+
"""The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually."""
106+
107+
@pytest.fixture
108+
def hook(self):
109+
return SnsHook(aws_conn_id="aws_default")
110+
111+
@pytest.fixture
112+
def mock_async_client(self):
113+
mock_client = mock.AsyncMock()
114+
mock_client.publish.return_value = {MESSAGE_ID_KEY: "test-message-id"}
115+
return mock_client
116+
117+
@pytest.fixture
118+
def mock_get_async_conn(self, mock_async_client):
119+
with mock.patch.object(SnsHook, "get_async_conn") as mocked_conn:
120+
mocked_conn.return_value = mock_async_client
121+
mocked_conn.return_value.__aenter__.return_value = mock_async_client
122+
yield mocked_conn
123+
124+
async def test_get_async_conn(self, hook, mock_get_async_conn, mock_async_client):
125+
# Test context manager access
126+
async with await hook.get_async_conn() as async_conn:
127+
assert async_conn is mock_async_client
128+
129+
# Test direct access
130+
async_conn = await hook.get_async_conn()
131+
assert async_conn is mock_async_client
132+
133+
async def test_apublish_to_target_with_subject(self, hook, mock_get_async_conn, mock_async_client):
134+
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, SUBJECT)
135+
136+
assert MESSAGE_ID_KEY in response
137+
138+
async def test_apublish_to_target_with_attributes(self, hook, mock_get_async_conn, mock_async_client):
139+
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=VALID_ATTRIBUTES)
140+
141+
assert MESSAGE_ID_KEY in response
142+
143+
async def test_publish_to_target_plain(self, hook, mock_get_async_conn, mock_async_client):
144+
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE)
145+
146+
assert MESSAGE_ID_KEY in response
147+
148+
async def test_publish_to_target_error(self, hook, mock_get_async_conn, mock_async_client):
149+
with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
150+
await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=INVALID_ATTRIBUTES)
151+
152+
async def test_apublish_to_target_with_deduplication(self, hook, mock_get_async_conn, mock_async_client):
153+
response = await hook.apublish_to_target(
154+
TOPIC_ARN, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
119155
)
120156

121-
assert "MessageId" in response
157+
assert MESSAGE_ID_KEY in response

0 commit comments

Comments
 (0)