1717# under the License.
1818from __future__ import annotations
1919
20+ from unittest import mock
21+
2022import pytest
2123from moto import mock_aws
2224
2325from airflow .providers .amazon .aws .hooks .sns import SnsHook
2426
27+ DEDUPE_ID = "test-dedupe-id"
28+ GROUP_ID = "test-group-id"
2529MESSAGE = "Hello world"
26- TOPIC_NAME = "test-topic"
2730SUBJECT = "test-subject"
31+ INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable"
2832
33+ TOPIC_NAME = "test-topic"
34+ TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{ TOPIC_NAME } "
2935
30- @mock_aws
31- class TestSnsHook :
32- def test_get_conn_returns_a_boto3_connection (self ):
33- hook = SnsHook (aws_conn_id = "aws_default" )
34- assert hook .get_conn () is not None
35-
36- def test_publish_to_target_with_subject (self ):
37- hook = SnsHook (aws_conn_id = "aws_default" )
38-
39- message = MESSAGE
40- topic_name = TOPIC_NAME
41- subject = SUBJECT
42- target = hook .get_conn ().create_topic (Name = topic_name ).get ("TopicArn" )
43-
44- response = hook .publish_to_target (target , message , subject )
36+ INVALID_ATTRIBUTES = {"test-non-iterable" : object ()}
37+ VALID_ATTRIBUTES = {
38+ "test-string" : "string-value" ,
39+ "test-number" : 123456 ,
40+ "test-array" : ["first" , "second" , "third" ],
41+ "test-binary" : b"binary-value" ,
42+ }
4543
46- assert "MessageId" in response
44+ MESSAGE_ID_KEY = "MessageId"
45+ TOPIC_ARN_KEY = "TopicArn"
4746
48- def test_publish_to_target_with_attributes (self ):
49- hook = SnsHook (aws_conn_id = "aws_default" )
5047
51- message = MESSAGE
52- topic_name = TOPIC_NAME
53- target = hook .get_conn ().create_topic (Name = topic_name ).get ("TopicArn" )
48+ class TestSnsHook :
49+ @pytest .fixture (autouse = True )
50+ def setup_moto (self ):
51+ with mock_aws ():
52+ yield
5453
55- response = hook .publish_to_target (
56- target ,
57- message ,
58- message_attributes = {
59- "test-string" : "string-value" ,
60- "test-number" : 123456 ,
61- "test-array" : ["first" , "second" , "third" ],
62- "test-binary" : b"binary-value" ,
63- },
64- )
54+ @pytest .fixture
55+ def hook (self ):
56+ return SnsHook (aws_conn_id = "aws_default" )
6557
66- assert "MessageId" in response
58+ @pytest .fixture
59+ def target (self , hook ):
60+ return hook .get_conn ().create_topic (Name = TOPIC_NAME ).get (TOPIC_ARN_KEY )
6761
68- def test_publish_to_target_plain (self ):
69- hook = SnsHook ( aws_conn_id = "aws_default" )
62+ def test_get_conn_returns_a_boto3_connection (self , hook ):
63+ assert hook . get_conn () is not None
7064
71- message = MESSAGE
72- topic_name = "test-topic"
73- target = hook .get_conn ().create_topic (Name = topic_name ).get ("TopicArn" )
65+ def test_publish_to_target_with_subject (self , hook , target ):
66+ response = hook .publish_to_target (target , MESSAGE , SUBJECT )
7467
75- response = hook . publish_to_target ( target , message )
68+ assert MESSAGE_ID_KEY in response
7669
77- assert "MessageId" in response
70+ def test_publish_to_target_with_attributes (self , hook , target ):
71+ response = hook .publish_to_target (target , MESSAGE , message_attributes = VALID_ATTRIBUTES )
7872
79- def test_publish_to_target_error (self ):
80- hook = SnsHook (aws_conn_id = "aws_default" )
73+ assert MESSAGE_ID_KEY in response
8174
82- message = "Hello world"
83- topic_name = "test-topic"
84- target = hook .get_conn ().create_topic (Name = topic_name ).get ("TopicArn" )
75+ def test_publish_to_target_plain (self , hook , target ):
76+ response = hook .publish_to_target (target , MESSAGE )
8577
86- error_message = (
87- r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*"
88- )
89- with pytest .raises (TypeError , match = error_message ):
90- hook .publish_to_target (
91- target ,
92- message ,
93- message_attributes = {
94- "test-non-iterable" : object (),
95- },
96- )
78+ assert MESSAGE_ID_KEY in response
9779
98- def test_publish_to_target_with_deduplication (self ):
99- hook = SnsHook (aws_conn_id = "aws_default" )
80+ def test_publish_to_target_error (self , hook , target ):
81+ with pytest .raises (TypeError , match = INVALID_ATTRIBUTES_MSG ):
82+ hook .publish_to_target (target , MESSAGE , message_attributes = INVALID_ATTRIBUTES )
10083
101- message = MESSAGE
102- topic_name = TOPIC_NAME + ".fifo"
103- deduplication_id = "abc"
104- group_id = "a"
105- target = (
84+ def test_publish_to_target_with_deduplication (self , hook ):
85+ fifo_target = (
10686 hook .get_conn ()
10787 .create_topic (
108- Name = topic_name ,
88+ Name = f" { TOPIC_NAME } .fifo" ,
10989 Attributes = {
11090 "FifoTopic" : "true" ,
11191 "ContentBasedDeduplication" : "false" ,
@@ -115,7 +95,63 @@ def test_publish_to_target_with_deduplication(self):
11595 )
11696
11797 response = hook .publish_to_target (
118- target , message , message_deduplication_id = deduplication_id , message_group_id = group_id
98+ fifo_target , MESSAGE , message_deduplication_id = DEDUPE_ID , message_group_id = GROUP_ID
99+ )
100+ assert MESSAGE_ID_KEY in response
101+
102+
103+ @pytest .mark .asyncio
104+ class TestAsyncSnsHook :
105+ """The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually."""
106+
107+ @pytest .fixture
108+ def hook (self ):
109+ return SnsHook (aws_conn_id = "aws_default" )
110+
111+ @pytest .fixture
112+ def mock_async_client (self ):
113+ mock_client = mock .AsyncMock ()
114+ mock_client .publish .return_value = {MESSAGE_ID_KEY : "test-message-id" }
115+ return mock_client
116+
117+ @pytest .fixture
118+ def mock_get_async_conn (self , mock_async_client ):
119+ with mock .patch .object (SnsHook , "get_async_conn" ) as mocked_conn :
120+ mocked_conn .return_value = mock_async_client
121+ mocked_conn .return_value .__aenter__ .return_value = mock_async_client
122+ yield mocked_conn
123+
124+ async def test_get_async_conn (self , hook , mock_get_async_conn , mock_async_client ):
125+ # Test context manager access
126+ async with await hook .get_async_conn () as async_conn :
127+ assert async_conn is mock_async_client
128+
129+ # Test direct access
130+ async_conn = await hook .get_async_conn ()
131+ assert async_conn is mock_async_client
132+
133+ async def test_apublish_to_target_with_subject (self , hook , mock_get_async_conn , mock_async_client ):
134+ response = await hook .apublish_to_target (TOPIC_ARN , MESSAGE , SUBJECT )
135+
136+ assert MESSAGE_ID_KEY in response
137+
138+ async def test_apublish_to_target_with_attributes (self , hook , mock_get_async_conn , mock_async_client ):
139+ response = await hook .apublish_to_target (TOPIC_ARN , MESSAGE , message_attributes = VALID_ATTRIBUTES )
140+
141+ assert MESSAGE_ID_KEY in response
142+
143+ async def test_publish_to_target_plain (self , hook , mock_get_async_conn , mock_async_client ):
144+ response = await hook .apublish_to_target (TOPIC_ARN , MESSAGE )
145+
146+ assert MESSAGE_ID_KEY in response
147+
148+ async def test_publish_to_target_error (self , hook , mock_get_async_conn , mock_async_client ):
149+ with pytest .raises (TypeError , match = INVALID_ATTRIBUTES_MSG ):
150+ await hook .apublish_to_target (TOPIC_ARN , MESSAGE , message_attributes = INVALID_ATTRIBUTES )
151+
152+ async def test_apublish_to_target_with_deduplication (self , hook , mock_get_async_conn , mock_async_client ):
153+ response = await hook .apublish_to_target (
154+ TOPIC_ARN , MESSAGE , message_deduplication_id = DEDUPE_ID , message_group_id = GROUP_ID
119155 )
120156
121- assert "MessageId" in response
157+ assert MESSAGE_ID_KEY in response
0 commit comments