1818from sagemaker import utils
1919
2020PRINCIPAL_TEMPLATE = (
21- '["{account_id}", "{role_arn}", ' '"arn:aws :iam::{account_id}:role/{sagemaker_role}"] '
21+ '["{account_id}", "{role_arn}", ' '"arn:{partition} :iam::{account_id}:role/{sagemaker_role}"] '
2222)
2323
2424KEY_ALIAS = "SageMakerTestKMSKey"
@@ -60,11 +60,14 @@ def _get_kms_key_id(kms_client, alias):
6060
6161
6262def _create_kms_key (
63- kms_client , account_id , role_arn = None , sagemaker_role = "SageMakerRole" , alias = KEY_ALIAS
63+ kms_client , account_id , region , role_arn = None , sagemaker_role = "SageMakerRole" , alias = KEY_ALIAS
6464):
6565 if role_arn :
6666 principal = PRINCIPAL_TEMPLATE .format (
67- account_id = account_id , role_arn = role_arn , sagemaker_role = sagemaker_role
67+ partition = utils ._aws_partition (region ),
68+ account_id = account_id ,
69+ role_arn = role_arn ,
70+ sagemaker_role = sagemaker_role ,
6871 )
6972 else :
7073 principal = '"{account_id}"' .format (account_id = account_id )
@@ -83,7 +86,7 @@ def _create_kms_key(
8386
8487
8588def _add_role_to_policy (
86- kms_client , account_id , role_arn , alias = KEY_ALIAS , sagemaker_role = "SageMakerRole"
89+ kms_client , account_id , role_arn , region , alias = KEY_ALIAS , sagemaker_role = "SageMakerRole"
8790):
8891 key_id = _get_kms_key_id (kms_client , alias )
8992 policy = kms_client .get_key_policy (KeyId = key_id , PolicyName = POLICY_NAME )
@@ -92,7 +95,10 @@ def _add_role_to_policy(
9295
9396 if role_arn not in principal or sagemaker_role not in principal :
9497 principal = PRINCIPAL_TEMPLATE .format (
95- account_id = account_id , role_arn = role_arn , sagemaker_role = sagemaker_role
98+ partition = utils ._aws_partition (region ),
99+ account_id = account_id ,
100+ role_arn = role_arn ,
101+ sagemaker_role = sagemaker_role ,
96102 )
97103
98104 kms_client .put_key_policy (
@@ -115,44 +121,44 @@ def get_or_create_kms_key(
115121 account_id = sts_client .get_caller_identity ()["Account" ]
116122
117123 if kms_key_arn is None :
118- return _create_kms_key (kms_client , account_id , role_arn , sagemaker_role , alias )
124+ return _create_kms_key (kms_client , account_id , region , role_arn , sagemaker_role , alias )
119125
120126 if role_arn :
121- _add_role_to_policy (kms_client , account_id , role_arn , alias , sagemaker_role )
127+ _add_role_to_policy (kms_client , account_id , role_arn , region , alias , sagemaker_role )
122128
123129 return kms_key_arn
124130
125131
126- KMS_BUCKET_POLICY = """{
132+ KMS_BUCKET_POLICY = """{{
127133 "Version": "2012-10-17",
128134 "Id": "PutObjPolicy",
129135 "Statement": [
130- {
136+ {{
131137 "Sid": "DenyIncorrectEncryptionHeader",
132138 "Effect": "Deny",
133139 "Principal": "*",
134140 "Action": "s3:PutObject",
135- "Resource": "arn:aws :s3:::%s /*",
136- "Condition": {
137- "StringNotEquals": {
138- "s3:x-amz-server-side-encryption": "aws :kms"
139- }
140- }
141- },
142- {
141+ "Resource": "arn:{partition} :s3:::{bucket_name} /*",
142+ "Condition": {{
143+ "StringNotEquals": {{
144+ "s3:x-amz-server-side-encryption": "{partition} :kms"
145+ }}
146+ }}
147+ }} ,
148+ {{
143149 "Sid": "DenyUnEncryptedObjectUploads",
144150 "Effect": "Deny",
145151 "Principal": "*",
146152 "Action": "s3:PutObject",
147- "Resource": "arn:aws :s3:::%s /*",
148- "Condition": {
149- "Null": {
153+ "Resource": "arn:{partition} :s3:::{bucket_name} /*",
154+ "Condition": {{
155+ "Null": {{
150156 "s3:x-amz-server-side-encryption": "true"
151- }
152- }
153- }
157+ }}
158+ }}
159+ }}
154160 ]
155- }"""
161+ }} """
156162
157163
158164@contextlib .contextmanager
@@ -167,7 +173,7 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
167173 role_arn = sts_client .get_caller_identity ()["Arn" ]
168174
169175 kms_client = boto_session .client ("kms" )
170- kms_key_arn = _create_kms_key (kms_client , account , role_arn , sagemaker_role , None )
176+ kms_key_arn = _create_kms_key (kms_client , account , region , role_arn , sagemaker_role , None )
171177
172178 region = boto_session .region_name
173179 bucket_name = "sagemaker-{}-{}-with-kms" .format (region , account )
@@ -181,7 +187,9 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
181187 "Rules" : [
182188 {
183189 "ApplyServerSideEncryptionByDefault" : {
184- "SSEAlgorithm" : "aws:kms" ,
190+ "SSEAlgorithm" : "{partition}:kms" .format (
191+ partition = utils ._aws_partition (region )
192+ ),
185193 "KMSMasterKeyID" : kms_key_arn ,
186194 }
187195 }
@@ -190,7 +198,10 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
190198 )
191199
192200 s3_client .put_bucket_policy (
193- Bucket = bucket_name , Policy = KMS_BUCKET_POLICY % (bucket_name , bucket_name )
201+ Bucket = bucket_name ,
202+ Policy = KMS_BUCKET_POLICY .format (
203+ partition = utils ._aws_partition (region ), bucket_name = bucket_name
204+ ),
194205 )
195206
196207 yield "s3://" + bucket_name , kms_key_arn
0 commit comments