 import json
 import os
 import sys
+import hmac
+import hashlib
 
 import cloudpickle
 
 from typing import Any, Callable
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
+from sagemaker.session import Session
+
 from tblib import pickling_support
 
 
@@ -34,6 +38,7 @@ def _get_python_version():
 class _MetaData:
     """Metadata about the serialized data or functions."""
 
+    sha256_hash: str
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
     serialization_module: str = "cloudpickle"
@@ -48,11 +53,17 @@ def from_json(s):
         except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
 
-        metadata = _MetaData()
+        sha256_hash = obj.get("sha256_hash")
+        metadata = _MetaData(sha256_hash=sha256_hash)
         metadata.version = obj.get("version")
         metadata.python_version = obj.get("python_version")
         metadata.serialization_module = obj.get("serialization_module")
 
+        if not sha256_hash:
+            raise DeserializationError(
+                "Corrupt metadata file. SHA256 hash for the serialized data does not exist"
+            )
+
         if not (
             metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
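With this change, the metadata document written next to each payload carries a keyed hash of the payload bytes, and `from_json` rejects any metadata file that lacks it before the version check runs. A minimal sketch of the resulting shape (the digest and Python version below are made-up placeholders, not output from this PR):

    # Hypothetical metadata.json contents after this change.
    metadata_document = {
        "version": "2023-04-24",
        "python_version": "3.10.8",  # placeholder; _get_python_version() fills this in
        "serialization_module": "cloudpickle",
        "sha256_hash": "9b2f...e1a7",  # placeholder hex HMAC-SHA256 of payload.pkl's bytes
    }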
@@ -67,20 +78,16 @@ class CloudpickleSerializer:
6778 """Serializer using cloudpickle."""
6879
6980 @staticmethod
70- def serialize (obj : Any , sagemaker_session , s3_uri : str , s3_kms_key : str = None ) :
81+ def serialize (obj : Any ) -> Any :
7182 """Serializes data object and uploads it to S3.
7283
7384 Args:
74- sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
75- calls are delegated to.
76- s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
77- s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
7885 obj: object to be serialized and persisted
7986 Raises:
8087 SerializationError: when fail to serialize object to bytes.
8188 """
         try:
-            bytes_to_upload = cloudpickle.dumps(obj)
+            return cloudpickle.dumps(obj)
         except Exception as e:
             if isinstance(
                 e, NotImplementedError
@@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
96103 "Error when serializing object of type [{}]: {}" .format (type (obj ).__name__ , repr (e ))
97104 ) from e
98105
99- _upload_bytes_to_s3 (bytes_to_upload , s3_uri , s3_kms_key , sagemaker_session )
100-
101106 @staticmethod
102- def deserialize (sagemaker_session , s3_uri ) -> Any :
107+ def deserialize (s3_uri : str , bytes_to_deserialize ) -> Any :
103108 """Downloads from S3 and then deserializes data objects.
104109
105110 Args:
@@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when the object cannot be deserialized from bytes.
         """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
         try:
             return cloudpickle.loads(bytes_to_deserialize)
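Net effect of the two hunks above: `CloudpickleSerializer` becomes a pure bytes-in/bytes-out layer. `serialize` returns the pickled bytes instead of uploading them, and `deserialize` takes bytes the caller has already downloaded (the `s3_uri` parameter appears to be kept only so error messages can name the source). A minimal sketch of the new call pattern, with made-up values and no S3 involved:

    # Sketch only: the dict and URI are placeholders.
    payload_bytes = CloudpickleSerializer.serialize({"answer": 42})
    restored = CloudpickleSerializer.deserialize("s3://bucket/prefix/payload.pkl", payload_bytes)
    assert restored == {"answer": 42}

Lifting the S3 I/O out into the module-level helpers is what makes the integrity check workable: the HMAC must be computed over exactly the bytes that go up, and verified over exactly the bytes that come back down.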
@@ -122,28 +126,39 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
 
 
 # TODO: use dask serializer in case dask distributed is installed in users' environment.
-def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
+def serialize_func_to_s3(
+    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes function and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized function.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
         func: function to be serialized and persisted
     Raises:
         SerializationError: when the function cannot be serialized to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(func)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
+def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
     """Downloads from S3 and then deserializes data objects.
 
     This method downloads the serialized training job outputs to a temporary directory and
@@ -153,61 +168,94 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
-        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized function.
     Returns:
         The deserialized function.
     Raises:
         DeserializationError: when the function cannot be deserialized from bytes.
     """
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
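Taken together, the pair above is the write-then-verify flow for functions. A hypothetical round trip (the key, bucket, and session are placeholders for illustration; how the key is actually provisioned is outside this diff):

    import secrets
    from sagemaker.session import Session

    hmac_key = secrets.token_hex(32)        # placeholder key
    session = Session()
    s3_root = "s3://my-bucket/remote-fn"    # hypothetical S3 prefix

    def add(a, b):
        return a + b

    serialize_func_to_s3(add, session, s3_root, hmac_key)
    restored = deserialize_func_from_s3(session, s3_root, hmac_key)
    assert restored(2, 3) == 5

    # Deserializing with a different key, or after the payload has been
    # altered in S3, raises DeserializationError from the integrity check.

The same upload-payload, hash, upload-metadata sequence repeats below for plain objects and for exceptions.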
-def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+def serialize_obj_to_s3(
+    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes data object and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized object.
         obj: object to be serialized and persisted
     Raises:
         SerializationError: when the object cannot be serialized to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(obj)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes data objects.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
-        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized object.
     Returns:
         Deserialized python objects.
     Raises:
         DeserializationError: when the object cannot be deserialized from bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def serialize_exception_to_s3(
-    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
 ):
     """Serializes exception with traceback and uploads it to S3.
 
@@ -216,37 +264,58 @@ def serialize_exception_to_s3(
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized exception.
         exc: Exception to be serialized and persisted
     Raises:
         SerializationError: when the exception cannot be serialized to bytes.
     """
     pickling_support.install()
+
+    bytes_to_upload = CloudpickleSerializer.serialize(exc)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes exception.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
-        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
+        hmac_key (str): Key used to calculate the HMAC-SHA256 hash of the serialized exception.
     Returns:
         Deserialized exception with traceback.
     Raises:
         DeserializationError: when the exception cannot be deserialized from bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
@@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
         raise ServiceError(
             "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e))
         ) from e
+
+
+def _compute_hash(buffer: bytes, secret_key: str) -> str:
+    """Computes the HMAC-SHA256 hash of the given bytes."""
+    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
+
+
+def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
+    """Performs an integrity check on serialized code/arguments uploaded to S3.
+
+    Verifies that the hash read from S3 matches the hash calculated
+    during remote function execution.
+    """
+    actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
+    if not hmac.compare_digest(expected_hash_value, actual_hash_value):
+        raise DeserializationError(
+            "Integrity check for the serialized function or data failed. "
+            "Please restrict access to your S3 bucket"
+        )
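Two details are worth noting in these helpers. First, the hash is a keyed HMAC rather than a plain SHA-256 digest, so an attacker who can overwrite both payload.pkl and metadata.json in the bucket still cannot forge a matching pair without the key. Second, `hmac.compare_digest` compares in constant time, avoiding the timing side channel of an ordinary string comparison. A self-contained demonstration with made-up key and data values:

    import hashlib
    import hmac

    def compute_hash(buffer: bytes, secret_key: str) -> str:
        # Same construction as _compute_hash above.
        return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()

    payload = b"serialized-bytes"                      # placeholder payload
    expected = compute_hash(payload, secret_key="k1")

    assert hmac.compare_digest(expected, compute_hash(payload, secret_key="k1"))
    assert not hmac.compare_digest(expected, compute_hash(b"tampered", secret_key="k1"))
    assert not hmac.compare_digest(expected, compute_hash(payload, secret_key="k2"))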