Skip to content

Commit 14417bb

Browse files
committed
bug fix for hmac key and remove remote function from train
1 parent 462bed0 commit 14417bb

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1269
-4874
lines changed

sagemaker-core/src/sagemaker/core/remote_function/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def wrapper(*args, **kwargs):
366366
s3_uri=s3_path_join(
367367
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
368368
),
369-
hmac_key=job.hmac_key,
369+
370370
)
371371
except ServiceError as serr:
372372
chained_e = serr.__cause__
@@ -403,7 +403,7 @@ def wrapper(*args, **kwargs):
403403
return serialization.deserialize_obj_from_s3(
404404
sagemaker_session=job_settings.sagemaker_session,
405405
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
406-
hmac_key=job.hmac_key,
406+
407407
)
408408

409409
if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -983,7 +983,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
983983
job_return = serialization.deserialize_obj_from_s3(
984984
sagemaker_session=sagemaker_session,
985985
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
986-
hmac_key=job.hmac_key,
986+
987987
)
988988
except DeserializationError as e:
989989
client_exception = e
@@ -995,7 +995,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
995995
job_exception = serialization.deserialize_exception_from_s3(
996996
sagemaker_session=sagemaker_session,
997997
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
998-
hmac_key=job.hmac_key,
998+
999999
)
10001000
except ServiceError as serr:
10011001
chained_e = serr.__cause__
@@ -1085,7 +1085,7 @@ def result(self, timeout: float = None) -> Any:
10851085
self._return = serialization.deserialize_obj_from_s3(
10861086
sagemaker_session=self._job.sagemaker_session,
10871087
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
1088-
hmac_key=self._job.hmac_key,
1088+
10891089
)
10901090
self._state = _FINISHED
10911091
return self._return
@@ -1094,7 +1094,7 @@ def result(self, timeout: float = None) -> Any:
10941094
self._exception = serialization.deserialize_exception_from_s3(
10951095
sagemaker_session=self._job.sagemaker_session,
10961096
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
1097-
hmac_key=self._job.hmac_key,
1097+
10981098
)
10991099
except ServiceError as serr:
11001100
chained_e = serr.__cause__

sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ class _DelayedReturnResolver:
164164
def __init__(
165165
self,
166166
delayed_returns: List[_DelayedReturn],
167-
hmac_key: str,
168167
properties_resolver: _PropertiesResolver,
169168
parameter_resolver: _ParameterResolver,
170169
execution_variable_resolver: _ExecutionVariableResolver,
@@ -175,7 +174,6 @@ def __init__(
175174
176175
Args:
177176
delayed_returns: list of delayed returns to resolve.
178-
hmac_key: key used to encrypt serialized and deserialized function and arguments.
179177
properties_resolver: resolver used to resolve step properties.
180178
parameter_resolver: resolver used to resolve pipeline parameters.
181179
execution_variable_resolver: resolver used to resolve execution variables.
@@ -197,7 +195,6 @@ def deserialization_task(uri):
197195
return uri, deserialize_obj_from_s3(
198196
sagemaker_session=settings["sagemaker_session"],
199197
s3_uri=uri,
200-
hmac_key=hmac_key,
201198
)
202199

203200
with ThreadPoolExecutor() as executor:
@@ -247,7 +244,6 @@ def resolve_pipeline_variables(
247244
context: Context,
248245
func_args: Tuple,
249246
func_kwargs: Dict,
250-
hmac_key: str,
251247
s3_base_uri: str,
252248
**settings,
253249
):
@@ -257,7 +253,6 @@ def resolve_pipeline_variables(
257253
context: context for the execution.
258254
func_args: function args.
259255
func_kwargs: function kwargs.
260-
hmac_key: key used to encrypt serialized and deserialized function and arguments.
261256
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
262257
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
263258
**settings: settings to pass to the deserialization function.
@@ -280,7 +275,6 @@ def resolve_pipeline_variables(
280275
properties_resolver = _PropertiesResolver(context)
281276
delayed_return_resolver = _DelayedReturnResolver(
282277
delayed_returns=delayed_returns,
283-
hmac_key=hmac_key,
284278
properties_resolver=properties_resolver,
285279
parameter_resolver=parameter_resolver,
286280
execution_variable_resolver=execution_variable_resolver,

sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import io
2020

2121
import sys
22-
import hmac
2322
import hashlib
2423
import pickle
2524

@@ -156,15 +155,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
156155

157156
# TODO: use dask serializer in case dask distributed is installed in users' environment.
158157
def serialize_func_to_s3(
159-
func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
158+
func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
160159
):
161160
"""Serializes function and uploads it to S3.
162161
163162
Args:
164163
sagemaker_session (sagemaker.core.helper.session.Session):
165164
The underlying Boto3 session which AWS service calls are delegated to.
166165
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
167-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
168166
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
169167
func: function to be serialized and persisted
170168
Raises:
@@ -173,14 +171,13 @@ def serialize_func_to_s3(
173171

174172
_upload_payload_and_metadata_to_s3(
175173
bytes_to_upload=CloudpickleSerializer.serialize(func),
176-
hmac_key=hmac_key,
177174
s3_uri=s3_uri,
178175
sagemaker_session=sagemaker_session,
179176
s3_kms_key=s3_kms_key,
180177
)
181178

182179

183-
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
180+
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
184181
"""Downloads from S3 and then deserializes data objects.
185182
186183
This method downloads the serialized training job outputs to a temporary directory and
@@ -190,7 +187,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
190187
sagemaker_session (sagemaker.core.helper.session.Session):
191188
The underlying sagemaker session which AWS service calls are delegated to.
192189
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
193-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
194190
Returns :
195191
The deserialized function.
196192
Raises:
@@ -203,14 +199,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
203199
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
204200

205201
_perform_integrity_check(
206-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
202+
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
207203
)
208204

209205
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
210206

211207

212208
def serialize_obj_to_s3(
213-
obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
209+
obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
214210
):
215211
"""Serializes data object and uploads it to S3.
216212
@@ -219,15 +215,13 @@ def serialize_obj_to_s3(
219215
The underlying Boto3 session which AWS service calls are delegated to.
220216
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
221217
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
222-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
223218
obj: object to be serialized and persisted
224219
Raises:
225220
SerializationError: when fail to serialize object to bytes.
226221
"""
227222

228223
_upload_payload_and_metadata_to_s3(
229224
bytes_to_upload=CloudpickleSerializer.serialize(obj),
230-
hmac_key=hmac_key,
231225
s3_uri=s3_uri,
232226
sagemaker_session=sagemaker_session,
233227
s3_kms_key=s3_kms_key,
@@ -274,14 +268,13 @@ def json_serialize_obj_to_s3(
274268
)
275269

276270

277-
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
271+
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
278272
"""Downloads from S3 and then deserializes data objects.
279273
280274
Args:
281275
sagemaker_session (sagemaker.core.helper.session.Session):
282276
The underlying sagemaker session which AWS service calls are delegated to.
283277
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
284-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
285278
Returns :
286279
Deserialized python objects.
287280
Raises:
@@ -295,14 +288,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
295288
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
296289

297290
_perform_integrity_check(
298-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
291+
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
299292
)
300293

301294
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
302295

303296

304297
def serialize_exception_to_s3(
305-
exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
298+
exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
306299
):
307300
"""Serializes exception with traceback and uploads it to S3.
308301
@@ -311,7 +304,6 @@ def serialize_exception_to_s3(
311304
The underlying Boto3 session which AWS service calls are delegated to.
312305
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
313306
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
314-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
315307
exc: Exception to be serialized and persisted
316308
Raises:
317309
SerializationError: when fail to serialize object to bytes.
@@ -320,7 +312,6 @@ def serialize_exception_to_s3(
320312

321313
_upload_payload_and_metadata_to_s3(
322314
bytes_to_upload=CloudpickleSerializer.serialize(exc),
323-
hmac_key=hmac_key,
324315
s3_uri=s3_uri,
325316
sagemaker_session=sagemaker_session,
326317
s3_kms_key=s3_kms_key,
@@ -329,7 +320,6 @@ def serialize_exception_to_s3(
329320

330321
def _upload_payload_and_metadata_to_s3(
331322
bytes_to_upload: Union[bytes, io.BytesIO],
332-
hmac_key: str,
333323
s3_uri: str,
334324
sagemaker_session: Session,
335325
s3_kms_key,
@@ -338,15 +328,14 @@ def _upload_payload_and_metadata_to_s3(
338328
339329
Args:
340330
bytes_to_upload (bytes): Serialized bytes to upload.
341-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
342331
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
343332
sagemaker_session (sagemaker.core.helper.session.Session):
344333
The underlying Boto3 session which AWS service calls are delegated to.
345334
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
346335
"""
347336
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
348337

349-
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
338+
sha256_hash = _compute_hash(bytes_to_upload)
350339

351340
_upload_bytes_to_s3(
352341
_MetaData(sha256_hash).to_json(),
@@ -356,14 +345,13 @@ def _upload_payload_and_metadata_to_s3(
356345
)
357346

358347

359-
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
348+
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
360349
"""Downloads from S3 and then deserializes exception.
361350
362351
Args:
363352
sagemaker_session (sagemaker.core.helper.session.Session):
364353
The underlying sagemaker session which AWS service calls are delegated to.
365354
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
366-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
367355
Returns :
368356
Deserialized exception with traceback.
369357
Raises:
@@ -377,7 +365,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
377365
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
378366

379367
_perform_integrity_check(
380-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
368+
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
381369
)
382370

383371
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -403,19 +391,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
403391
) from e
404392

405393

406-
def _compute_hash(buffer: bytes, secret_key: str) -> str:
407-
"""Compute the hmac-sha256 hash"""
408-
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
394+
def _compute_hash(buffer: bytes) -> str:
395+
"""Compute the sha256 hash"""
396+
return hashlib.sha256(buffer).hexdigest()
409397

410398

411-
def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
399+
def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
412400
"""Performs integrity checks for serialized code/arguments uploaded to s3.
413401
414402
Verifies whether the hash read from s3 matches the hash calculated
415403
during remote function execution.
416404
"""
417-
actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
418-
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
405+
actual_hash_value = _compute_hash(buffer=buffer)
406+
if expected_hash_value != actual_hash_value:
419407
raise DeserializationError(
420408
"Integrity check for the serialized function or data failed. "
421409
"Please restrict access to your S3 bucket"

0 commit comments

Comments
 (0)