 import json
 import os
 import sys
+import hmac
+import hashlib
 
 import cloudpickle
 
 from typing import Any, Callable
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
+from sagemaker.session import Session
+
 from tblib import pickling_support
 
 
@@ -34,6 +38,7 @@ def _get_python_version():
 class _MetaData:
     """Metadata about the serialized data or functions."""
 
+    sha256_hash: str
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
     serialization_module: str = "cloudpickle"
@@ -48,11 +53,17 @@ def from_json(s):
         except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
 
-        metadata = _MetaData()
+        sha256_hash = obj.get("sha256_hash")
+        metadata = _MetaData(sha256_hash=sha256_hash)
         metadata.version = obj.get("version")
         metadata.python_version = obj.get("python_version")
         metadata.serialization_module = obj.get("serialization_module")
 
+        if not sha256_hash:
+            raise DeserializationError(
+                "Corrupt metadata file. SHA256 hash for the serialized data does not exist"
+            )
+
         if not (
             metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
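
For reference, the metadata document this hunk validates would now look roughly like the following (a sketch; field values are illustrative placeholders, and `to_json` is assumed to emit the dataclass fields as a flat JSON object):

```python
import json

# Illustrative metadata.json contents after this change (values are placeholders):
doc = {
    "sha256_hash": "b1946ac9...",  # hex digest computed by _compute_hash (added below)
    "version": "2023-04-24",
    "python_version": "3.10",
    "serialization_module": "cloudpickle",
}
print(json.dumps(doc))

# A document missing "sha256_hash" now fails fast in _MetaData.from_json with
# DeserializationError("Corrupt metadata file. SHA256 hash for the serialized data does not exist")
```
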
@@ -67,20 +78,16 @@ class CloudpickleSerializer:
     """Serializer using cloudpickle."""
 
     @staticmethod
-    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+    def serialize(obj: Any) -> Any:
         """Serializes data object and uploads it to S3.
 
         Args:
-            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
-                calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
             obj: object to be serialized and persisted
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
         try:
-            bytes_to_upload = cloudpickle.dumps(obj)
+            return cloudpickle.dumps(obj)
         except Exception as e:
             if isinstance(
                 e, NotImplementedError
@@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e
 
-        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
-
     @staticmethod
-    def deserialize(sagemaker_session, s3_uri) -> Any:
+    def deserialize(s3_uri: str, bytes_to_deserialize) -> Any:
         """Downloads from S3 and then deserializes data objects.
 
         Args:
@@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to serialize object to bytes.
         """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
         try:
             return cloudpickle.loads(bytes_to_deserialize)
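
With the S3 I/O lifted out of `CloudpickleSerializer`, serialize and deserialize become a pure bytes-in/bytes-out pair that can be exercised without any AWS calls. A minimal round-trip sketch (the `s3_uri` argument appears to be retained only for error reporting):

```python
# Local round trip through the refactored, I/O-free serializer.
# The URI is just a label here; no network access is involved.
payload = CloudpickleSerializer.serialize({"lr": 0.01, "epochs": 3})
restored = CloudpickleSerializer.deserialize("s3://bucket/prefix/payload.pkl", payload)
assert restored == {"lr": 0.01, "epochs": 3}
```
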
@@ -122,28 +126,39 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
 
 
 # TODO: use dask serializer in case dask distributed is installed in users' environment.
-def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
+def serialize_func_to_s3(
+    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes function and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
         func: function to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(func)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
+def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
     """Downloads from S3 and then deserializes data objects.
 
     This method downloads the serialized training job outputs to a temporary directory and
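
After this hunk, every `serialize_*_to_s3` call writes two objects under `s3_uri`: the cloudpickled bytes first, then a metadata document carrying the HMAC-SHA256 of those exact bytes. A hypothetical caller (bucket name, prefix, and key value are placeholders):

```python
# Hypothetical write path; the hash is computed over bytes_to_upload before
# metadata.json is written, so the stored digest always matches payload.pkl.
serialize_func_to_s3(
    func=lambda x: x * 2,
    sagemaker_session=Session(),
    s3_uri="s3://my-bucket/remote-fn/run-1/function",
    hmac_key="<secret shared between client and job>",
)
# Resulting layout:
#   s3://my-bucket/remote-fn/run-1/function/payload.pkl    <- cloudpickled function
#   s3://my-bucket/remote-fn/run-1/function/metadata.json  <- includes sha256_hash of payload.pkl
```
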
@@ -153,61 +168,94 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
     Returns:
         The deserialized function.
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
-def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+def serialize_obj_to_s3(
+    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
+):
     """Serializes data object and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
         obj: object to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize object to bytes.
     """
 
+    bytes_to_upload = CloudpickleSerializer.serialize(obj)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes data objects.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
     Returns:
         Deserialized python objects.
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def serialize_exception_to_s3(
-    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
 ):
     """Serializes exception with traceback and uploads it to S3.
 
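
The read path now recomputes the HMAC over the downloaded payload and compares it to the hash stored in metadata.json before anything is unpickled, so a payload swapped in the bucket is rejected before `cloudpickle.loads` ever runs. A self-contained sketch using the helpers added at the bottom of this diff:

```python
key = "shared-secret"  # placeholder; in practice this comes from the remote-function setup
payload = CloudpickleSerializer.serialize([1, 2, 3])
good_hash = _compute_hash(payload, secret_key=key)

# Matches: the check passes silently.
_perform_integrity_check(expected_hash_value=good_hash, secret_key=key, buffer=payload)

# A tampered payload.pkl no longer reaches cloudpickle.loads.
try:
    _perform_integrity_check(expected_hash_value=good_hash, secret_key=key, buffer=payload + b"x")
except DeserializationError as e:
    print("rejected before deserialization:", e)
```
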
@@ -216,37 +264,58 @@ def serialize_exception_to_s3(
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
         s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
         exc: Exception to be serialized and persisted
     Raises:
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
+
+    bytes_to_upload = CloudpickleSerializer.serialize(exc)
+
     _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
     )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+
+    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
+
+    _upload_bytes_to_s3(
+        _MetaData(sha256_hash).to_json(),
+        os.path.join(s3_uri, "metadata.json"),
+        s3_kms_key,
+        sagemaker_session,
     )
 
 
-def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
     """Downloads from S3 and then deserializes exception.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
             calls are delegated to.
         s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
     Returns:
         Deserialized exception with traceback.
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
 
-    _MetaData.from_json(
+    metadata = _MetaData.from_json(
         _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
 
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    bytes_to_deserialize = _read_bytes_from_s3(
+        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
+    )
+
+    _perform_integrity_check(
+        expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
+    )
+
+    return CloudpickleSerializer.deserialize(
+        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
+    )
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
@@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
         raise ServiceError(
             "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e))
         ) from e
+
+
+def _compute_hash(buffer: bytes, secret_key: str) -> str:
+    """Compute the hmac-sha256 hash"""
+    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
+
+
+def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
+    """Performs integrity checks for serialized code/arguments uploaded to s3.
+
+    Verifies whether the hash read from s3 matches the hash calculated
+    during remote function execution.
+    """
+    actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
+    if not hmac.compare_digest(expected_hash_value, actual_hash_value):
+        raise DeserializationError(
+            "Integrity check for the serialized function or data failed. "
+            "Please restrict access to your S3 bucket"
+        )
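
`_compute_hash` is a thin wrapper over the standard library, and `compare_digest` is a constant-time comparison, so the check leaks no timing information about the expected digest. A quick equivalence sketch:

```python
import hmac
import hashlib

secret, payload = "shared-secret", b"serialized bytes"

# Equivalent to _compute_hash(payload, secret_key=secret):
digest = hmac.new(secret.encode(), msg=payload, digestmod=hashlib.sha256).hexdigest()
assert digest == _compute_hash(payload, secret_key=secret)

# Unlike ==, compare_digest takes time independent of where two strings
# first differ, which blocks timing-based recovery of the expected hash.
assert hmac.compare_digest(digest, digest)
```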