Skip to content

Commit 0ac0100

Browse files
committed
refactor: Move logic for getting regions from the config to common.py
Currently, the `region` key can be configured in the image configuration. However, sibce `sns` also support the `region` key to send a notifications to specific regions, this logic can be resued to get the sns regions. Move the logic to common.py and rename to `_get_regions` as a more general function for retrieving regions from different keys.
1 parent 1789f05 commit 0ac0100

File tree

4 files changed

+92
-27
lines changed

4 files changed

+92
-27
lines changed

awspub/common.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from typing import Tuple
1+
import logging
2+
from typing import List, Tuple
3+
4+
import boto3
5+
from mypy_boto3_ec2.client import EC2Client
6+
7+
logger = logging.getLogger(__name__)
28

39

410
def _split_partition(val: str) -> Tuple[str, str]:
@@ -16,3 +22,43 @@ def _split_partition(val: str) -> Tuple[str, str]:
1622
partition = "aws"
1723
resource = val
1824
return partition, resource
25+
26+
27+
def _get_regions(region_to_query: str, regions_allowlist: List[str]) -> List[str]:
28+
"""
29+
Get a list of region names querying the `region_to_query` for all regions and
30+
then filtering by `regions_allowlist`.
31+
If no `regions_allowlist` is given, all queried regions are returned for the
32+
current partition.
33+
If `regions_allowlist` is given, all regions from that list are returned if
34+
the listed region exist in the current partition.
35+
Eg. `us-east-1` listed in `regions_allowlist` won't be returned if the current
36+
partition is `aws-cn`.
37+
:param region_to_query: region name of current partition
38+
:type region_to_query: str
39+
:praram regions_allowlist: list of regions in config file
40+
:type regions_allowlist: List[str]
41+
:return: list of regions names
42+
:rtype: List[str]
43+
"""
44+
45+
# get all available regions
46+
ec2client: EC2Client = boto3.client("ec2", region_name=region_to_query)
47+
resp = ec2client.describe_regions()
48+
ec2_regions_all = [r["RegionName"] for r in resp["Regions"]]
49+
50+
if regions_allowlist:
51+
# filter out regions that are not available in the current partition
52+
regions_allowlist_set = set(regions_allowlist)
53+
ec2_regions_all_set = set(ec2_regions_all)
54+
regions = list(regions_allowlist_set.intersection(ec2_regions_all_set))
55+
diff = regions_allowlist_set.difference(ec2_regions_all_set)
56+
if diff:
57+
logger.warning(
58+
f"regions {diff} listed in regions allowlist are not available in the current partition."
59+
" Ignoring those."
60+
)
61+
else:
62+
regions = ec2_regions_all
63+
64+
return regions

awspub/image.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mypy_boto3_ssm import SSMClient
1010

1111
from awspub import exceptions
12-
from awspub.common import _split_partition
12+
from awspub.common import _get_regions, _split_partition
1313
from awspub.context import Context
1414
from awspub.image_marketplace import ImageMarketplace
1515
from awspub.s3 import S3
@@ -121,28 +121,11 @@ def snapshot_name(self) -> str:
121121
@property
122122
def image_regions(self) -> List[str]:
123123
"""
124-
Get the image regions. Either configured in the image configuration
125-
or all available regions.
126-
If a region is listed that is not available in the currently used partition,
127-
that region will be ignored (eg. having us-east-1 configured but running in the aws-cn
128-
partition doesn't include us-east-1 here).
124+
Get the image regions.
129125
"""
130126
if not self._image_regions_cached:
131-
# get all available regions
132-
ec2client: EC2Client = boto3.client("ec2", region_name=self._s3.bucket_region)
133-
resp = ec2client.describe_regions()
134-
image_regions_all = [r["RegionName"] for r in resp["Regions"]]
135-
136-
if self.conf["regions"]:
137-
# filter out regions that are not available in the current partition
138-
image_regions_configured_set = set(self.conf["regions"])
139-
image_regions_all_set = set(image_regions_all)
140-
self._image_regions = list(image_regions_configured_set.intersection(image_regions_all_set))
141-
diff = image_regions_configured_set.difference(image_regions_all_set)
142-
if diff:
143-
logger.warning(f"configured regions {diff} not available in the current partition. Ignoring those.")
144-
else:
145-
self._image_regions = image_regions_all
127+
regions_configured = self.conf["regions"] if "regions" in self.conf else []
128+
self._image_regions = _get_regions(self._s3.bucket_region, regions_configured)
146129
self._image_regions_cached = True
147130
return self._image_regions
148131

@@ -358,14 +341,16 @@ def _sns_publish(self) -> None:
358341
"""
359342
Publish SNS notifiations about newly available images to subscribers
360343
"""
344+
345+
# Checking if the image(s) are registered or published before sending the notification
361346
for region in self.image_regions:
362347
ec2client_region: EC2Client = boto3.client("ec2", region_name=region)
363348
image_info: Optional[_ImageInfo] = self._get(ec2client_region)
364-
365349
if not image_info:
366350
logger.error(f"can not send SNS notification for {self.image_name} because no image found in {region}")
367351
return
368-
SNSNotification(self._ctx, self.image_name, region).publish()
352+
353+
SNSNotification(self._ctx, self.image_name, self._s3.bucket_region).publish()
369354

370355
def cleanup(self) -> None:
371356
"""

awspub/sns.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from mypy_boto3_sns.client import SNSClient
1212
from mypy_boto3_sts.client import STSClient
1313

14+
from awspub.common import _get_regions
1415
from awspub.context import Context
1516
from awspub.exceptions import AWSAuthorizationException, AWSNotificationException
17+
from awspub.s3 import S3
1618

1719
logger = logging.getLogger(__name__)
1820

@@ -23,13 +25,13 @@ class SNSNotification(object):
2325
structuring rules for SNS notification JSON
2426
"""
2527

26-
def __init__(self, context: Context, image_name: str, region_name: str):
28+
def __init__(self, context: Context, image_name: str):
2729
"""
2830
Construct a message and verify that it is valid
2931
"""
3032
self._ctx: Context = context
3133
self._image_name: str = image_name
32-
self._region_name: str = region_name
34+
self._s3: S3 = S3(context)
3335

3436
@property
3537
def conf(self) -> List[Dict[str, Any]]:
@@ -39,6 +41,20 @@ def conf(self) -> List[Dict[str, Any]]:
3941
return self._ctx.conf["images"][self._image_name]["sns"]
4042

4143
def _get_topic_arn(self, topic_name: str) -> str:
44+
def _sns_regions(self, topic_config: Dict[Any, Any]) -> List[str]:
45+
"""
46+
Get the sns regions. Either configured in the sns configuration
47+
or all available regions.
48+
If a region is listed that is not available in the currently used partition,
49+
that region will be ignored (eg. having us-east-1 configured but running in the aws-cn
50+
partition doesn't include us-east-1 here).
51+
"""
52+
53+
regions_configured = topic_config["regions"] if "regions" in topic_config else []
54+
sns_regions = _get_regions(self._s3.bucket_region, regions_configured)
55+
56+
return sns_regions
57+
4258
"""
4359
Calculate topic ARN based on partition, region, account and topic name
4460
:param topic_name: Name of topic

awspub/tests/test_common.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

3-
from awspub.common import _split_partition
5+
from awspub.common import _get_regions, _split_partition
46

57

68
@pytest.mark.parametrize(
@@ -14,3 +16,19 @@
1416
)
1517
def test_common__split_partition(input, expected_output):
1618
assert _split_partition(input) == expected_output
19+
20+
21+
@pytest.mark.parametrize(
22+
"regions_in_partition,configured_regions,expected_output",
23+
[
24+
(["region-1", "region-2"], ["region-1", "region-3"], ["region-1"]),
25+
(["region-1", "region-2", "region-3"], ["region-4", "region-5"], []),
26+
(["region-1", "region-2"], [], ["region-1", "region-2"]),
27+
],
28+
)
29+
def test_common__get_regions(regions_in_partition, configured_regions, expected_output):
30+
with patch("boto3.client") as bclient_mock:
31+
instance = bclient_mock.return_value
32+
instance.describe_regions.return_value = {"Regions": [{"RegionName": r} for r in regions_in_partition]}
33+
34+
assert _get_regions("", configured_regions) == expected_output

0 commit comments

Comments
 (0)