Skip to content

Commit 60023d0

Browse files
authored
infra: programmatically determine partition based on region (#1316)
1 parent bffd0f1 commit 60023d0

File tree

4 files changed

+85
-35
lines changed

4 files changed

+85
-35
lines changed

src/sagemaker/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import contextlib
1717
import errno
18+
import logging
1819
import os
1920
import random
2021
import re
@@ -38,6 +39,8 @@
3839
HTTPS_PREFIX = "https://"
3940
DEFAULT_SLEEP_TIME_SECONDS = 10
4041

42+
logger = logging.getLogger(__name__)
43+
4144

4245
# Use the base name of the image as the job name if the user doesn't give us one
4346
def name_from_image(image):
@@ -667,6 +670,21 @@ def _botocore_resolver():
667670
return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
668671

669672

673+
def _aws_partition(region):
674+
"""
675+
Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").
676+
677+
Args:
678+
region (str): The region name for which to return the corresponding partition.
679+
Ex: "cn-north-1"
680+
681+
Returns:
682+
str: partition corresponding to the region name passed in. Ex: "aws-cn"
683+
"""
684+
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
685+
return endpoint_data["partition"]
686+
687+
670688
class DeferredError(object):
671689
"""Stores an exception and raises it at a later time if this object is
672690
accessed in any way. Useful to allow soft-dependencies on imports, so that

tests/integ/kms_utils.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sagemaker import utils
1919

2020
PRINCIPAL_TEMPLATE = (
21-
'["{account_id}", "{role_arn}", ' '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] '
21+
'["{account_id}", "{role_arn}", ' '"arn:{partition}:iam::{account_id}:role/{sagemaker_role}"] '
2222
)
2323

2424
KEY_ALIAS = "SageMakerTestKMSKey"
@@ -60,11 +60,14 @@ def _get_kms_key_id(kms_client, alias):
6060

6161

6262
def _create_kms_key(
63-
kms_client, account_id, role_arn=None, sagemaker_role="SageMakerRole", alias=KEY_ALIAS
63+
kms_client, account_id, region, role_arn=None, sagemaker_role="SageMakerRole", alias=KEY_ALIAS
6464
):
6565
if role_arn:
6666
principal = PRINCIPAL_TEMPLATE.format(
67-
account_id=account_id, role_arn=role_arn, sagemaker_role=sagemaker_role
67+
partition=utils._aws_partition(region),
68+
account_id=account_id,
69+
role_arn=role_arn,
70+
sagemaker_role=sagemaker_role,
6871
)
6972
else:
7073
principal = '"{account_id}"'.format(account_id=account_id)
@@ -83,7 +86,7 @@ def _create_kms_key(
8386

8487

8588
def _add_role_to_policy(
86-
kms_client, account_id, role_arn, alias=KEY_ALIAS, sagemaker_role="SageMakerRole"
89+
kms_client, account_id, role_arn, region, alias=KEY_ALIAS, sagemaker_role="SageMakerRole"
8790
):
8891
key_id = _get_kms_key_id(kms_client, alias)
8992
policy = kms_client.get_key_policy(KeyId=key_id, PolicyName=POLICY_NAME)
@@ -92,7 +95,10 @@ def _add_role_to_policy(
9295

9396
if role_arn not in principal or sagemaker_role not in principal:
9497
principal = PRINCIPAL_TEMPLATE.format(
95-
account_id=account_id, role_arn=role_arn, sagemaker_role=sagemaker_role
98+
partition=utils._aws_partition(region),
99+
account_id=account_id,
100+
role_arn=role_arn,
101+
sagemaker_role=sagemaker_role,
96102
)
97103

98104
kms_client.put_key_policy(
@@ -115,44 +121,44 @@ def get_or_create_kms_key(
115121
account_id = sts_client.get_caller_identity()["Account"]
116122

117123
if kms_key_arn is None:
118-
return _create_kms_key(kms_client, account_id, role_arn, sagemaker_role, alias)
124+
return _create_kms_key(kms_client, account_id, region, role_arn, sagemaker_role, alias)
119125

120126
if role_arn:
121-
_add_role_to_policy(kms_client, account_id, role_arn, alias, sagemaker_role)
127+
_add_role_to_policy(kms_client, account_id, role_arn, region, alias, sagemaker_role)
122128

123129
return kms_key_arn
124130

125131

126-
KMS_BUCKET_POLICY = """{
132+
KMS_BUCKET_POLICY = """{{
127133
"Version": "2012-10-17",
128134
"Id": "PutObjPolicy",
129135
"Statement": [
130-
{
136+
{{
131137
"Sid": "DenyIncorrectEncryptionHeader",
132138
"Effect": "Deny",
133139
"Principal": "*",
134140
"Action": "s3:PutObject",
135-
"Resource": "arn:aws:s3:::%s/*",
136-
"Condition": {
137-
"StringNotEquals": {
138-
"s3:x-amz-server-side-encryption": "aws:kms"
139-
}
140-
}
141-
},
142-
{
141+
"Resource": "arn:{partition}:s3:::{bucket_name}/*",
142+
"Condition": {{
143+
"StringNotEquals": {{
144+
"s3:x-amz-server-side-encryption": "{partition}:kms"
145+
}}
146+
}}
147+
}},
148+
{{
143149
"Sid": "DenyUnEncryptedObjectUploads",
144150
"Effect": "Deny",
145151
"Principal": "*",
146152
"Action": "s3:PutObject",
147-
"Resource": "arn:aws:s3:::%s/*",
148-
"Condition": {
149-
"Null": {
153+
"Resource": "arn:{partition}:s3:::{bucket_name}/*",
154+
"Condition": {{
155+
"Null": {{
150156
"s3:x-amz-server-side-encryption": "true"
151-
}
152-
}
153-
}
157+
}}
158+
}}
159+
}}
154160
]
155-
}"""
161+
}}"""
156162

157163

158164
@contextlib.contextmanager
@@ -167,7 +173,7 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
167173
role_arn = sts_client.get_caller_identity()["Arn"]
168174

169175
kms_client = boto_session.client("kms")
170-
kms_key_arn = _create_kms_key(kms_client, account, role_arn, sagemaker_role, None)
176+
kms_key_arn = _create_kms_key(kms_client, account, region, role_arn, sagemaker_role, None)
171177

172178
region = boto_session.region_name
173179
bucket_name = "sagemaker-{}-{}-with-kms".format(region, account)
@@ -181,7 +187,9 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
181187
"Rules": [
182188
{
183189
"ApplyServerSideEncryptionByDefault": {
184-
"SSEAlgorithm": "aws:kms",
190+
"SSEAlgorithm": "{partition}:kms".format(
191+
partition=utils._aws_partition(region)
192+
),
185193
"KMSMasterKeyID": kms_key_arn,
186194
}
187195
}
@@ -190,7 +198,10 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
190198
)
191199

192200
s3_client.put_bucket_policy(
193-
Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
201+
Bucket=bucket_name,
202+
Policy=KMS_BUCKET_POLICY.format(
203+
partition=utils._aws_partition(region), bucket_name=bucket_name
204+
),
194205
)
195206

196207
yield "s3://" + bucket_name, kms_key_arn

tests/integ/test_marketplace.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sagemaker import AlgorithmEstimator, ModelPackage
2525
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
2626
from sagemaker.utils import sagemaker_timestamp
27+
from sagemaker.utils import _aws_partition
2728
from tests.integ import DATA_DIR
2829
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2930
from tests.integ.marketplace_utils import REGION_ACCOUNT_MAP
@@ -39,12 +40,12 @@
3940
# Both are written by Amazon and are free to subscribe.
4041

4142
ALGORITHM_ARN = (
42-
"arn:aws:sagemaker:%s:%s:algorithm/scikit-decision-trees-"
43+
"arn:{partition}:sagemaker:{region}:{account}:algorithm/scikit-decision-trees-"
4344
"15423055-57b73412d2e93e9239e4e16f83298b8f"
4445
)
4546

4647
MODEL_PACKAGE_ARN = (
47-
"arn:aws:sagemaker:%s:%s:model-package/scikit-iris-detector-"
48+
"arn:{partition}:sagemaker:{region}:{account}:model-package/scikit-iris-detector-"
4849
"154230595-8f00905c1f927a512b73ea29dd09ae30"
4950
)
5051

@@ -63,7 +64,9 @@ def test_marketplace_estimator(sagemaker_session, cpu_instance_type):
6364
data_path = os.path.join(DATA_DIR, "marketplace", "training")
6465
region = sagemaker_session.boto_region_name
6566
account = REGION_ACCOUNT_MAP[region]
66-
algorithm_arn = ALGORITHM_ARN % (region, account)
67+
algorithm_arn = ALGORITHM_ARN.format(
68+
partition=_aws_partition(region), region=region, account=account
69+
)
6770

6871
algo = AlgorithmEstimator(
6972
algorithm_arn=algorithm_arn,
@@ -103,7 +106,9 @@ def test_marketplace_attach(sagemaker_session, cpu_instance_type):
103106
data_path = os.path.join(DATA_DIR, "marketplace", "training")
104107
region = sagemaker_session.boto_region_name
105108
account = REGION_ACCOUNT_MAP[region]
106-
algorithm_arn = ALGORITHM_ARN % (region, account)
109+
algorithm_arn = ALGORITHM_ARN.format(
110+
partition=_aws_partition(region), region=region, account=account
111+
)
107112

108113
mktplace = AlgorithmEstimator(
109114
algorithm_arn=algorithm_arn,
@@ -155,7 +160,9 @@ def test_marketplace_attach(sagemaker_session, cpu_instance_type):
155160
def test_marketplace_model(sagemaker_session, cpu_instance_type):
156161
region = sagemaker_session.boto_region_name
157162
account = REGION_ACCOUNT_MAP[region]
158-
model_package_arn = MODEL_PACKAGE_ARN % (region, account)
163+
model_package_arn = MODEL_PACKAGE_ARN.format(
164+
partition=_aws_partition(region), region=region, account=account
165+
)
159166

160167
def predict_wrapper(endpoint, session):
161168
return sagemaker.RealTimePredictor(
@@ -192,7 +199,9 @@ def test_marketplace_tuning_job(sagemaker_session, cpu_instance_type):
192199
data_path = os.path.join(DATA_DIR, "marketplace", "training")
193200
region = sagemaker_session.boto_region_name
194201
account = REGION_ACCOUNT_MAP[region]
195-
algorithm_arn = ALGORITHM_ARN % (region, account)
202+
algorithm_arn = ALGORITHM_ARN.format(
203+
partition=_aws_partition(region), region=region, account=account
204+
)
196205

197206
mktplace = AlgorithmEstimator(
198207
algorithm_arn=algorithm_arn,
@@ -233,7 +242,9 @@ def test_marketplace_transform_job(sagemaker_session, cpu_instance_type):
233242
data_path = os.path.join(DATA_DIR, "marketplace", "training")
234243
region = sagemaker_session.boto_region_name
235244
account = REGION_ACCOUNT_MAP[region]
236-
algorithm_arn = ALGORITHM_ARN % (region, account)
245+
algorithm_arn = ALGORITHM_ARN.format(
246+
partition=_aws_partition(region), region=region, account=account
247+
)
237248

238249
algo = AlgorithmEstimator(
239250
algorithm_arn=algorithm_arn,
@@ -279,7 +290,9 @@ def test_marketplace_transform_job_from_model_package(sagemaker_session, cpu_ins
279290

280291
region = sagemaker_session.boto_region_name
281292
account = REGION_ACCOUNT_MAP[region]
282-
model_package_arn = MODEL_PACKAGE_ARN % (region, account)
293+
model_package_arn = MODEL_PACKAGE_ARN.format(
294+
partition=_aws_partition(region), region=region, account=account
295+
)
283296

284297
model = ModelPackage(
285298
role="SageMakerRole",

tests/unit/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,11 @@ def test_sts_regional_endpoint():
699699
endpoint = sagemaker.utils.sts_regional_endpoint("us-iso-east-1")
700700
assert endpoint == "https://sts.us-iso-east-1.c2s.ic.gov"
701701
assert botocore.utils.is_valid_endpoint_url(endpoint)
702+
703+
704+
def test_partition_by_region():
705+
assert sagemaker.utils._aws_partition("us-west-2") == "aws"
706+
assert sagemaker.utils._aws_partition("cn-north-1") == "aws-cn"
707+
assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov"
708+
assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso"
709+
assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b"

0 commit comments

Comments
 (0)