Skip to content

Commit 7355655

Browse files
authored
Merge pull request #74 from JessicaJang/sns-region-setting
fix: Allow SNS config has separate list of regions
2 parents 0e4ef82 + 31964b4 commit 7355655

File tree

11 files changed

+221
-83
lines changed

11 files changed

+221
-83
lines changed

awspub/common.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from typing import Tuple
1+
import logging
2+
from typing import List, Tuple
3+
4+
import boto3
5+
from mypy_boto3_ec2.client import EC2Client
6+
7+
logger = logging.getLogger(__name__)
28

39

410
def _split_partition(val: str) -> Tuple[str, str]:
@@ -16,3 +22,43 @@ def _split_partition(val: str) -> Tuple[str, str]:
1622
partition = "aws"
1723
resource = val
1824
return partition, resource
25+
26+
27+
def _get_regions(region_to_query: str, regions_allowlist: List[str]) -> List[str]:
28+
"""
29+
Get a list of region names querying the `region_to_query` for all regions and
30+
then filtering by `regions_allowlist`.
31+
If no `regions_allowlist` is given, all queried regions are returned for the
32+
current partition.
33+
If `regions_allowlist` is given, all regions from that list are returned if
34+
the listed region exist in the current partition.
35+
Eg. `us-east-1` listed in `regions_allowlist` won't be returned if the current
36+
partition is `aws-cn`.
37+
:param region_to_query: region name of current partition
38+
:type region_to_query: str
39+
:praram regions_allowlist: list of regions in config file
40+
:type regions_allowlist: List[str]
41+
:return: list of regions names
42+
:rtype: List[str]
43+
"""
44+
45+
# get all available regions
46+
ec2client: EC2Client = boto3.client("ec2", region_name=region_to_query)
47+
resp = ec2client.describe_regions()
48+
ec2_regions_all = [r["RegionName"] for r in resp["Regions"]]
49+
50+
if regions_allowlist:
51+
# filter out regions that are not available in the current partition
52+
regions_allowlist_set = set(regions_allowlist)
53+
ec2_regions_all_set = set(ec2_regions_all)
54+
regions = list(regions_allowlist_set.intersection(ec2_regions_all_set))
55+
diff = regions_allowlist_set.difference(ec2_regions_all_set)
56+
if diff:
57+
logger.warning(
58+
f"regions {diff} listed in regions allowlist are not available in the current partition."
59+
" Ignoring those."
60+
)
61+
else:
62+
regions = ec2_regions_all
63+
64+
return regions

awspub/configmodels.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ class ConfigImageSNSNotificationModel(BaseModel):
112112
description="The body of the message to be sent to subscribers.",
113113
default={SNSNotificationProtocol.DEFAULT: ""},
114114
)
115+
regions: Optional[List[str]] = Field(
116+
description="Optional list of regions for sending notification. If not given, regions where the image "
117+
"registered will be used from the currently used parition. If a region doesn't exist in the currently "
118+
"used partition, it will be ignored.",
119+
default=None,
120+
)
115121

116122
@field_validator("message")
117123
def check_message(cls, value):

awspub/image.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mypy_boto3_ssm import SSMClient
1010

1111
from awspub import exceptions
12-
from awspub.common import _split_partition
12+
from awspub.common import _get_regions, _split_partition
1313
from awspub.context import Context
1414
from awspub.image_marketplace import ImageMarketplace
1515
from awspub.s3 import S3
@@ -121,28 +121,11 @@ def snapshot_name(self) -> str:
121121
@property
122122
def image_regions(self) -> List[str]:
123123
"""
124-
Get the image regions. Either configured in the image configuration
125-
or all available regions.
126-
If a region is listed that is not available in the currently used partition,
127-
that region will be ignored (eg. having us-east-1 configured but running in the aws-cn
128-
partition doesn't include us-east-1 here).
124+
Get the image regions.
129125
"""
130126
if not self._image_regions_cached:
131-
# get all available regions
132-
ec2client: EC2Client = boto3.client("ec2", region_name=self._s3.bucket_region)
133-
resp = ec2client.describe_regions()
134-
image_regions_all = [r["RegionName"] for r in resp["Regions"]]
135-
136-
if self.conf["regions"]:
137-
# filter out regions that are not available in the current partition
138-
image_regions_configured_set = set(self.conf["regions"])
139-
image_regions_all_set = set(image_regions_all)
140-
self._image_regions = list(image_regions_configured_set.intersection(image_regions_all_set))
141-
diff = image_regions_configured_set.difference(image_regions_all_set)
142-
if diff:
143-
logger.warning(f"configured regions {diff} not available in the current partition. Ignoring those.")
144-
else:
145-
self._image_regions = image_regions_all
127+
regions_configured = self.conf["regions"] if "regions" in self.conf else []
128+
self._image_regions = _get_regions(self._s3.bucket_region, regions_configured)
146129
self._image_regions_cached = True
147130
return self._image_regions
148131

@@ -358,14 +341,8 @@ def _sns_publish(self) -> None:
358341
"""
359342
Publish SNS notifiations about newly available images to subscribers
360343
"""
361-
for region in self.image_regions:
362-
ec2client_region: EC2Client = boto3.client("ec2", region_name=region)
363-
image_info: Optional[_ImageInfo] = self._get(ec2client_region)
364344

365-
if not image_info:
366-
logger.error(f"can not send SNS notification for {self.image_name} because no image found in {region}")
367-
return
368-
SNSNotification(self._ctx, self.image_name, region).publish()
345+
SNSNotification(self._ctx, self.image_name).publish()
369346

370347
def cleanup(self) -> None:
371348
"""

awspub/sns.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from mypy_boto3_sns.client import SNSClient
1212
from mypy_boto3_sts.client import STSClient
1313

14+
from awspub.common import _get_regions
1415
from awspub.context import Context
1516
from awspub.exceptions import AWSAuthorizationException, AWSNotificationException
17+
from awspub.s3 import S3
1618

1719
logger = logging.getLogger(__name__)
1820

@@ -23,13 +25,13 @@ class SNSNotification(object):
2325
structuring rules for SNS notification JSON
2426
"""
2527

26-
def __init__(self, context: Context, image_name: str, region_name: str):
28+
def __init__(self, context: Context, image_name: str):
2729
"""
2830
Construct a message and verify that it is valid
2931
"""
3032
self._ctx: Context = context
3133
self._image_name: str = image_name
32-
self._region_name: str = region_name
34+
self._s3: S3 = S3(context)
3335

3436
@property
3537
def conf(self) -> List[Dict[str, Any]]:
@@ -38,7 +40,21 @@ def conf(self) -> List[Dict[str, Any]]:
3840
"""
3941
return self._ctx.conf["images"][self._image_name]["sns"]
4042

41-
def _get_topic_arn(self, topic_name: str) -> str:
43+
def _sns_regions(self, topic_config: Dict[Any, Any]) -> List[str]:
44+
"""
45+
Get the sns regions. Either configured in the sns configuration
46+
or all available regions.
47+
If a region is listed that is not available in the currently used partition,
48+
that region will be ignored (eg. having us-east-1 configured but running in the aws-cn
49+
partition doesn't include us-east-1 here).
50+
"""
51+
52+
regions_configured = topic_config["regions"] if "regions" in topic_config else []
53+
sns_regions = _get_regions(self._s3.bucket_region, regions_configured)
54+
55+
return sns_regions
56+
57+
def _get_topic_arn(self, topic_name: str, region_name: str) -> str:
4258
"""
4359
Calculate topic ARN based on partition, region, account and topic name
4460
:param topic_name: Name of topic
@@ -49,40 +65,41 @@ def _get_topic_arn(self, topic_name: str) -> str:
4965
:rtype: str
5066
"""
5167

52-
stsclient: STSClient = boto3.client("sts", region_name=self._region_name)
68+
stsclient: STSClient = boto3.client("sts", region_name=region_name)
5369
resp = stsclient.get_caller_identity()
5470

5571
account = resp["Account"]
5672
# resp["Arn"] has string format "arn:partition:iam::accountnumber:user/iam_role"
5773
partition = resp["Arn"].rsplit(":")[1]
5874

59-
return f"arn:{partition}:sns:{self._region_name}:{account}:{topic_name}"
75+
return f"arn:{partition}:sns:{region_name}:{account}:{topic_name}"
6076

6177
def publish(self) -> None:
6278
"""
6379
send notification to subscribers
6480
"""
6581

66-
snsclient: SNSClient = boto3.client("sns", region_name=self._region_name)
67-
6882
for topic in self.conf:
6983
for topic_name, topic_config in topic.items():
70-
try:
71-
snsclient.publish(
72-
TopicArn=self._get_topic_arn(topic_name),
73-
Subject=topic_config["subject"],
74-
Message=json.dumps(topic_config["message"]),
75-
MessageStructure="json",
76-
)
77-
except ClientError as e:
78-
exception_code: str = e.response["Error"]["Code"]
79-
if exception_code == "AuthorizationError":
80-
raise AWSAuthorizationException(
81-
"Profile does not have a permission to send the SNS notification. Please review the policy."
84+
for region_name in self._sns_regions(topic_config):
85+
snsclient: SNSClient = boto3.client("sns", region_name=region_name)
86+
try:
87+
snsclient.publish(
88+
TopicArn=self._get_topic_arn(topic_name, region_name),
89+
Subject=topic_config["subject"],
90+
Message=json.dumps(topic_config["message"]),
91+
MessageStructure="json",
8292
)
83-
else:
84-
raise AWSNotificationException(str(e))
85-
logger.info(
86-
f"The SNS notification {topic_config['subject']}"
87-
f" for the topic {topic_name} in {self._region_name} has been sent."
88-
)
93+
except ClientError as e:
94+
exception_code: str = e.response["Error"]["Code"]
95+
if exception_code == "AuthorizationError":
96+
raise AWSAuthorizationException(
97+
"Profile does not have a permission to send the SNS notification."
98+
" Please review the policy."
99+
)
100+
else:
101+
raise AWSNotificationException(str(e))
102+
logger.info(
103+
f"The SNS notification {topic_config['subject']}"
104+
f" for the topic {topic_name} in {region_name} has been sent."
105+
)

awspub/tests/fixtures/config1.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ awspub:
130130
message:
131131
default: "default-message"
132132
email: "email-message"
133+
regions:
134+
- "us-east-1"
133135
"test-image-11":
134136
boot_mode: "uefi"
135137
description: |
@@ -143,10 +145,27 @@ awspub:
143145
message:
144146
default: "default-message"
145147
email: "email-message"
148+
regions:
149+
- "us-east-1"
146150
- "topic2":
147151
subject: "topic2-subject"
148152
message:
149153
default: "default-message"
154+
regions:
155+
- "us-gov-1"
156+
- "eu-central-1"
157+
"test-image-12":
158+
boot_mode: "uefi"
159+
description: |
160+
A test image without a separate snapshot but single sns configs
161+
regions:
162+
- "us-east-1"
163+
sns:
164+
- "topic1":
165+
subject: "topic1-subject"
166+
message:
167+
default: "default-message"
168+
email: "email-message"
150169

151170
tags:
152171
name: "foobar"

awspub/tests/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"test-image-9",
2626
"test-image-10",
2727
"test-image-11",
28+
"test-image-12",
2829
],
2930
),
3031
# with a group that no image as, no image should be processed

awspub/tests/test_common.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

3-
from awspub.common import _split_partition
5+
from awspub.common import _get_regions, _split_partition
46

57

68
@pytest.mark.parametrize(
@@ -14,3 +16,19 @@
1416
)
1517
def test_common__split_partition(input, expected_output):
1618
assert _split_partition(input) == expected_output
19+
20+
21+
@pytest.mark.parametrize(
22+
"regions_in_partition,configured_regions,expected_output",
23+
[
24+
(["region-1", "region-2"], ["region-1", "region-3"], ["region-1"]),
25+
(["region-1", "region-2", "region-3"], ["region-4", "region-5"], []),
26+
(["region-1", "region-2"], [], ["region-1", "region-2"]),
27+
],
28+
)
29+
def test_common__get_regions(regions_in_partition, configured_regions, expected_output):
30+
with patch("boto3.client") as bclient_mock:
31+
instance = bclient_mock.return_value
32+
instance.describe_regions.return_value = {"Regions": [{"RegionName": r} for r in regions_in_partition]}
33+
34+
assert _get_regions("", configured_regions) == expected_output

awspub/tests/test_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def test_image___get_root_device_snapshot_id(root_device_name, block_device_mapp
143143
("test-image-8", "aws-cn", True, True, False, True, False),
144144
("test-image-10", "aws", False, False, False, False, True),
145145
("test-image-11", "aws", False, False, False, False, True),
146+
("test-image-12", "aws", False, False, False, False, True),
146147
],
147148
)
148149
def test_image_publish(
@@ -183,7 +184,6 @@ def test_image_publish(
183184
"Regions": [{"RegionName": "eu-central-1"}, {"RegionName": "us-east-1"}]
184185
}
185186
instance.list_buckets.return_value = {"Buckets": [{"Name": "bucket1"}]}
186-
instance.list_topics.return_value = {"Topics": [{"TopicArn": "arn:aws:sns:topic1"}]}
187187
ctx = context.Context(curdir / "fixtures/config1.yaml", None)
188188
img = image.Image(ctx, imagename)
189189
img.publish()

0 commit comments

Comments
 (0)