Commit 1c3f4b1

fix: Perform integrity checks for remote function execution (aws#3854)

rohangujarathi (Rohan Gujarathi) and co-author
Co-authored-by: Rohan Gujarathi <[email protected]>

1 parent 47354f0 · commit 1c3f4b1
File tree

13 files changed: +378 −107 lines changed
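In short: every payload the remote function feature round-trips through S3 (the pickled function, its results, and any raised exception) now travels with an HMAC-SHA256 digest stored alongside it in metadata.json, keyed by a per-job secret (job.hmac_key), and every deserialization path verifies that digest before anything is unpickled. An attacker who can write to the bucket but does not hold the key cannot forge a payload that passes the check. A minimal, self-contained sketch of the scheme (not SDK code: the fake_s3 dict and the hmac_key literal stand in for the real S3 round trip and the key the SDK generates per job):

```python
import hashlib
import hmac
import json

import cloudpickle  # the same serializer the SDK uses

fake_s3 = {}  # stand-in for the job's S3 prefix
hmac_key = "per-job-secret"  # illustrative; the SDK generates this per job

# Writer side: pickle, hash, then persist payload + metadata.
payload = cloudpickle.dumps(lambda x: x + 1)
digest = hmac.new(hmac_key.encode(), payload, hashlib.sha256).hexdigest()
fake_s3["payload.pkl"] = payload
fake_s3["metadata.json"] = json.dumps({"sha256_hash": digest}).encode()

# Reader side: verify the digest BEFORE unpickling anything.
meta = json.loads(fake_s3["metadata.json"])
body = fake_s3["payload.pkl"]
actual = hmac.new(hmac_key.encode(), body, hashlib.sha256).hexdigest()
if not hmac.compare_digest(meta["sha256_hash"], actual):
    raise RuntimeError("payload was altered in S3; refusing to unpickle")
func = cloudpickle.loads(body)
assert func(1) == 2
```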

src/sagemaker/remote_function/client.py

Lines changed: 6 additions & 0 deletions
@@ -301,6 +301,7 @@ def wrapper(*args, **kwargs):
                         s3_uri=s3_path_join(
                             job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
                         ),
+                        hmac_key=job.hmac_key,
                     )
                 except ServiceError as serr:
                     chained_e = serr.__cause__
@@ -337,6 +338,7 @@ def wrapper(*args, **kwargs):
             return serialization.deserialize_obj_from_s3(
                 sagemaker_session=job_settings.sagemaker_session,
                 s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
+                hmac_key=job.hmac_key,
             )
 
             if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -861,6 +863,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
             job_return = serialization.deserialize_obj_from_s3(
                 sagemaker_session=sagemaker_session,
                 s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
+                hmac_key=job.hmac_key,
             )
         except DeserializationError as e:
             client_exception = e
@@ -872,6 +875,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
             job_exception = serialization.deserialize_exception_from_s3(
                 sagemaker_session=sagemaker_session,
                 s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
+                hmac_key=job.hmac_key,
             )
         except ServiceError as serr:
             chained_e = serr.__cause__
@@ -961,6 +965,7 @@ def result(self, timeout: float = None) -> Any:
                 self._return = serialization.deserialize_obj_from_s3(
                     sagemaker_session=self._job.sagemaker_session,
                     s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
+                    hmac_key=self._job.hmac_key,
                 )
                 self._state = _FINISHED
                 return self._return
@@ -969,6 +974,7 @@ def result(self, timeout: float = None) -> Any:
             self._exception = serialization.deserialize_exception_from_s3(
                 sagemaker_session=self._job.sagemaker_session,
                 s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
+                hmac_key=self._job.hmac_key,
             )
         except ServiceError as serr:
             chained_e = serr.__cause__
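Every call site in client.py now threads the job's HMAC key into the serialization layer, where the check happens before any unpickling; omitting the new keyword argument raises a TypeError rather than silently skipping the check. A hedged sketch of the resulting call pattern (the helper name, the "results" folder literal, and the argument values are illustrative, not SDK code):

```python
from sagemaker.remote_function.core import serialization
from sagemaker.s3 import s3_path_join
from sagemaker.session import Session

def fetch_result(job_name: str, s3_root_uri: str, hmac_key: str):
    """Illustrative helper mirroring the wrapper() code path above."""
    return serialization.deserialize_obj_from_s3(
        sagemaker_session=Session(),  # default boto3 credentials/region
        # the SDK uses its RESULTS_FOLDER constant here; "results" is assumed
        s3_uri=s3_path_join(s3_root_uri, job_name, "results"),
        hmac_key=hmac_key,  # new in this commit; mismatch -> DeserializationError
    )
```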

src/sagemaker/remote_function/core/serialization.py

Lines changed: 120 additions & 32 deletions
@@ -17,12 +17,16 @@
 import json
 import os
 import sys
+import hmac
+import hashlib
 
 import cloudpickle
 
 from typing import Any, Callable
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
+from sagemaker.session import Session
+
 from tblib import pickling_support
 
 
@@ -34,6 +38,7 @@ def _get_python_version():
 class _MetaData:
     """Metadata about the serialized data or functions."""
 
+    sha256_hash: str
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
     serialization_module: str = "cloudpickle"
@@ -48,11 +53,17 @@ def from_json(s):
         except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
 
-        metadata = _MetaData()
+        sha256_hash = obj.get("sha256_hash")
+        metadata = _MetaData(sha256_hash=sha256_hash)
         metadata.version = obj.get("version")
         metadata.python_version = obj.get("python_version")
         metadata.serialization_module = obj.get("serialization_module")
 
+        if not sha256_hash:
+            raise DeserializationError(
+                "Corrupt metadata file. SHA256 hash for the serialized data does not exist"
+            )
+
         if not (
             metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
@@ -67,20 +78,16 @@ class CloudpickleSerializer:
     """Serializer using cloudpickle."""
 
     @staticmethod
-    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+    def serialize(obj: Any) -> Any:
         """Serializes data object and uploads it to S3.
 
         Args:
-            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
-                calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
            obj: object to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
         try:
-            bytes_to_upload = cloudpickle.dumps(obj)
+            return cloudpickle.dumps(obj)
         except Exception as e:
             if isinstance(
                 e, NotImplementedError
@@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e
 
-        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
-
     @staticmethod
-    def deserialize(sagemaker_session, s3_uri) -> Any:
+    def deserialize(s3_uri: str, bytes_to_deserialize) -> Any:
         """Downloads from S3 and then deserializes data objects.
 
         Args:
@@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to serialize object to bytes.
         """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
         try:
             return cloudpickle.loads(bytes_to_deserialize)
@@ -122,28 +126,39 @@
 
 
 # TODO: use dask serializer in case dask distributed is installed in users' environment.
-def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
+def serialize_func_to_s3(
+    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes function and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
         func: function to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(func)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
+def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
     """Downloads from S3 and then deserializes data objects.
 
     This method downloads the serialized training job outputs to a temporary directory and
@@ -153,61 +168,94 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
     Returns :
         The deserialized function.
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
-def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+def serialize_obj_to_s3(
+    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes data object and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
        obj: object to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize object to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(obj)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes data objects.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
     Returns :
         Deserialized python objects.
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def serialize_exception_to_s3(
-    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
 ):
     """Serializes exception with traceback and uploads it to S3.
 
@@ -216,37 +264,58 @@ def serialize_exception_to_s3(
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
         exc: Exception to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
+
+    bytes_to_upload = CloudpickleSerializer.serialize(exc)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes exception.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
     Returns :
         Deserialized exception with traceback.
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
@@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
         raise ServiceError(
             "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e))
         ) from e
+
+
+def _compute_hash(buffer: bytes, secret_key: str) -> str:
+    """Compute the hmac-sha256 hash"""
+    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
+
+
+def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
+    """Performs integrity checks for serialized code/arguments uploaded to s3.
+
+    Verifies whether the hash read from s3 matches the hash calculated
+    during remote function execution.
+    """
+    actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
+    if not hmac.compare_digest(expected_hash_value, actual_hash_value):
+        raise DeserializationError(
+            "Integrity check for the serialized function or data failed. "
+            "Please restrict access to your S3 bucket"
+        )
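These two helpers are the heart of the fix: because the digest is keyed, an attacker who can overwrite payload.pkl in the bucket cannot recompute a matching hash without the per-job secret, and hmac.compare_digest keeps the comparison constant-time so the expected value cannot be probed byte by byte. A quick standalone check of that behavior (stdlib only, reimplemented here so it runs without the SDK; _compute_hash mirrors the function in the diff):

```python
import hashlib
import hmac

def _compute_hash(buffer: bytes, secret_key: str) -> str:
    """Same HMAC-SHA256 construction as the SDK helper above."""
    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()

key = "per-job-secret"  # illustrative stand-in for the job's hmac_key
expected = _compute_hash(b"pickled-bytes", key)

# The untampered payload verifies; a modified one does not.
assert hmac.compare_digest(expected, _compute_hash(b"pickled-bytes", key))
assert not hmac.compare_digest(expected, _compute_hash(b"tampered-bytes", key))
```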
