diff --git a/awscli/botocore/httpchecksum.py b/awscli/botocore/httpchecksum.py index 02fb105384c6..7d52507ce3b7 100644 --- a/awscli/botocore/httpchecksum.py +++ b/awscli/botocore/httpchecksum.py @@ -75,6 +75,10 @@ def update(self, chunk): def digest(self): return self._int_crc32.to_bytes(4, byteorder="big") + @property + def int_crc(self): + return self._int_crc32 + class CrtCrc32Checksum(BaseChecksum): # Note: This class is only used if the CRT is available @@ -88,6 +92,10 @@ def update(self, chunk): def digest(self): return self._int_crc32.to_bytes(4, byteorder="big") + @property + def int_crc(self): + return self._int_crc32 + class CrtCrc32cChecksum(BaseChecksum): # Note: This class is only used if the CRT is available @@ -101,6 +109,10 @@ def update(self, chunk): def digest(self): return self._int_crc32c.to_bytes(4, byteorder="big") + @property + def int_crc(self): + return self._int_crc32c + class CrtCrc64NvmeChecksum(BaseChecksum): # Note: This class is only used if the CRT is available @@ -114,6 +126,10 @@ def update(self, chunk): def digest(self): return self._int_crc64nvme.to_bytes(8, byteorder="big") + @property + def int_crc(self): + return self._int_crc64nvme + class Sha1Checksum(BaseChecksum): def __init__(self): @@ -225,6 +241,10 @@ def __init__(self, raw_stream, content_length, checksum, expected): self._checksum = checksum self._expected = expected + @property + def checksum(self): + return self._checksum + def read(self, amt=None): chunk = super().read(amt=amt) self._checksum.update(chunk) diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index d99cdcb31bf3..5ac63556f37e 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -389,6 +389,7 @@ def _list_single_object(self, s3_path): try: params = {'Bucket': bucket, 'Key': key} params.update(self.request_parameters.get('HeadObject', {})) + params["ChecksumMode"] = "ENABLED" response = 
self._client.head_object(**params) except ClientError as e: # We want to try to give a more helpful error message. diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index 23176f30f889..3084df3f0fdd 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -37,6 +37,7 @@ DeleteSourceFileSubscriber, DeleteSourceObjectSubscriber, DirectoryCreatorSubscriber, + ProvideChecksumSubscriber, ProvideETagSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, @@ -421,6 +422,9 @@ def _add_additional_subscribers(self, subscribers, fileinfo): subscribers.append( DeleteSourceObjectSubscriber(fileinfo.source_client) ) + subscribers.append( + ProvideChecksumSubscriber(fileinfo.associated_response_data) + ) def _submit_transfer_request(self, fileinfo, extra_args, subscribers): bucket, key = find_bucket_key(fileinfo.src) diff --git a/awscli/customizations/s3/subscribers.py b/awscli/customizations/s3/subscribers.py index 68dc05b65a97..7ddcc2d3add5 100644 --- a/awscli/customizations/s3/subscribers.py +++ b/awscli/customizations/s3/subscribers.py @@ -16,6 +16,7 @@ import time from botocore.utils import percent_encode_sequence +from s3transfer.checksums import provide_checksum_to_meta from s3transfer.subscribers import BaseSubscriber from awscli.customizations.s3 import utils @@ -99,6 +100,26 @@ def on_queued(self, future, **kwargs): ) +class ProvideChecksumSubscriber(BaseSubscriber): + """ + A subscriber which provides the object stored checksum and algorithm. + """ + + def __init__(self, response_data): + self.response_data = response_data + + def on_queued(self, future, **kwargs): + if hasattr(future.meta, 'provide_stored_checksum') and hasattr( + future.meta, 'provide_checksum_algorithm' + ): + provide_checksum_to_meta(self.response_data, future.meta) + else: + LOGGER.debug( + f"Not providing stored checksum. 
Future: {future} does not " + "offer the capability to notify the checksum of an object", + ) + + class DeleteSourceSubscriber(OnDoneFilteredSubscriber): """A subscriber which deletes the source of the transfer.""" diff --git a/awscli/s3transfer/checksums.py b/awscli/s3transfer/checksums.py new file mode 100644 index 000000000000..53220c66c080 --- /dev/null +++ b/awscli/s3transfer/checksums.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +""" +NOTE: All classes and functions in this module are considered private and are +subject to abrupt breaking changes. Please do not use them directly. +""" + +import base64 +import logging +from copy import copy +from functools import cached_property + +from botocore.httpchecksum import CrtCrc32Checksum +from s3transfer.exceptions import S3ValidationError + +logger = logging.getLogger(__name__) + + +class PartStreamingChecksumBody: + def __init__(self, stream, starting_index, full_object_checksum): + self._stream = stream + self._starting_index = starting_index + self._checksum = CRC_CHECKSUM_CLS[ + full_object_checksum.checksum_algorithm + ]() + self._full_object_checksum = full_object_checksum + # If the underlying stream already has a checksum object + # it's updating (eg `botocore.httpchecksum.StreamingChecksumBody`), + # reuse its calculated value. 
+ self._reuse_checksum = hasattr(self._stream, 'checksum') + + @property + def checksum(self): + return self._checksum + + def read(self, *args, **kwargs): + value = self._stream.read(*args, **kwargs) + if not self._reuse_checksum: + self._checksum.update(value) + if not value: + self._set_part_checksum() + return value + + def _set_part_checksum(self): + if not self._reuse_checksum: + value = self._checksum.int_crc + else: + value = self._stream.checksum.int_crc + self._full_object_checksum.set_part_checksum( + self._starting_index, + value, + ) + + +class FullObjectChecksum: + def __init__(self, checksum_algorithm, content_length): + self.checksum_algorithm = checksum_algorithm + self._content_length = content_length + self._combine_function = _CRC_CHECKSUM_TO_COMBINE_FUNCTION[ + self.checksum_algorithm + ] + self._stored_checksum = None + self._part_checksums = None + self._calculated_checksum = None + + @cached_property + def calculated_checksum(self): + if self._calculated_checksum is None: + self._combine_part_checksums() + return self._calculated_checksum + + def set_stored_checksum(self, stored_checksum): + self._stored_checksum = stored_checksum + + def set_part_checksum(self, offset, checksum): + if self._part_checksums is None: + self._part_checksums = {} + self._part_checksums[offset] = checksum + + def _combine_part_checksums(self): + if self._part_checksums is None: + return + + sorted_offsets = sorted(self._part_checksums.keys()) + # Initialize the combined checksum to the first part's checksum value. + combined = self._part_checksums[sorted_offsets[0]] + # To calculate part length, take the current offset and subtract from + # the next offset. If the current offset is the start of the last part, + # then subtract from the total content length. 
eg, + # (8388608, 16777216) + # (16777216, self._content_length) + remaining_offsets = sorted_offsets[1:] + next_offsets = sorted_offsets[2:] + [self._content_length] + for offset, next_offset in zip(remaining_offsets, next_offsets): + part_checksum = self._part_checksums[offset] + offset_len = next_offset - offset + combined = self._combine_function( + combined, part_checksum, offset_len + ) + + self._calculated_checksum = base64.b64encode( + combined.to_bytes(4, byteorder='big') + ).decode('ascii') + + def validate(self): + if self.calculated_checksum != self._stored_checksum: + raise S3ValidationError( + f"Calculated checksum {self.calculated_checksum} does not match " + f"stored checksum {self._stored_checksum}" + ) + logger.debug( + f"Successfully validated stored checksum {self._stored_checksum} " + f"against calculated checksum {self.calculated_checksum}" + ) + + +def provide_checksum_to_meta(response, transfer_meta): + stored_checksum = None + checksum_algorithm = None + checksum_type = response.get("ChecksumType") + if checksum_type and checksum_type == "FULL_OBJECT": + for crc_checksum in CRC_CHECKSUMS: + if checksum_value := response.get(crc_checksum): + stored_checksum = checksum_value + checksum_algorithm = crc_checksum + break + transfer_meta.provide_checksum_algorithm(checksum_algorithm) + transfer_meta.provide_stored_checksum(stored_checksum) + + +def combine_crc32(crc1, crc2, len2): + """Combine two CRC32 values. + + :type crc1: int + :param crc1: Current CRC32 integer value. + + :type crc2: int + :param crc2: Second CRC32 integer value to combine. + + :type len2: int + :param len2: Length of data that produced `crc2`. + + :rtype: int + :returns: Combined CRC32 integer value. 
+ """ + _GF2_DIM = 32 + _CRC32_POLY = 0xEDB88320 + _MASK_32BIT = 0xFFFFFFFF + + def _gf2_matrix_times(mat, vec): + res = 0 + idx = 0 + while vec != 0: + if vec & 1: + res ^= mat[idx] + vec >>= 1 + idx += 1 + return res + + def _gf2_matrix_square(square, mat): + res = copy(square) + for n in range(_GF2_DIM): + d = mat[n] + res[n] = _gf2_matrix_times(mat, d) + return res + + even = [0] * _GF2_DIM + odd = [0] * _GF2_DIM + + if len2 <= 0: + return crc1 + + odd[0] = _CRC32_POLY + row = 1 + for i in range(1, _GF2_DIM): + odd[i] = row + row <<= 1 + + even = _gf2_matrix_square(even, odd) + odd = _gf2_matrix_square(odd, even) + + while True: + even = _gf2_matrix_square(even, odd) + if len2 & 1: + crc1 = _gf2_matrix_times(even, crc1) + len2 >>= 1 + + if len2 == 0: + break + + odd = _gf2_matrix_square(odd, even) + if len2 & 1: + crc1 = _gf2_matrix_times(odd, crc1) + len2 >>= 1 + + if len2 == 0: + break + + return (crc1 ^ crc2) & _MASK_32BIT + + +_CRC_CHECKSUM_TO_COMBINE_FUNCTION = { + "ChecksumCRC32": combine_crc32, +} + + +CRC_CHECKSUM_CLS = { + "ChecksumCRC32": CrtCrc32Checksum, +} + + +CRC_CHECKSUMS = _CRC_CHECKSUM_TO_COMBINE_FUNCTION.keys() diff --git a/awscli/s3transfer/download.py b/awscli/s3transfer/download.py index 9307e48fa551..fcc15cf5cd9a 100644 --- a/awscli/s3transfer/download.py +++ b/awscli/s3transfer/download.py @@ -15,6 +15,11 @@ import threading from botocore.exceptions import ClientError +from s3transfer.checksums import ( + FullObjectChecksum, + PartStreamingChecksumBody, + provide_checksum_to_meta, +) from s3transfer.compat import seekable from s3transfer.exceptions import ( RetriesExceededError, @@ -150,6 +155,14 @@ def _get_fileobj_from_filename(self, filename): self._transfer_coordinator.add_failure_cleanup(f.close) return f + def get_validate_checksum_task(self, full_object_checksum): + return ValidateChecksumTask( + transfer_coordinator=self._transfer_coordinator, + main_kwargs={ + 'full_object_checksum': full_object_checksum, + }, + ) + class 
DownloadFilenameOutputManager(DownloadOutputManager): def __init__(self, osutil, transfer_coordinator, io_executor): @@ -354,10 +367,12 @@ def _submit( if ( transfer_future.meta.size is None or transfer_future.meta.etag is None + or not transfer_future.meta.checksum_is_provided ): response = client.head_object( Bucket=transfer_future.meta.call_args.bucket, Key=transfer_future.meta.call_args.key, + ChecksumMode="ENABLED", **transfer_future.meta.call_args.extra_args, ) # If a size was not provided figure out the size for the @@ -368,6 +383,8 @@ def _submit( # Provide an etag to ensure a stored object is not modified # during a multipart download. transfer_future.meta.provide_object_etag(response.get('ETag')) + # Provide stored checksum value and algorithm. + provide_checksum_to_meta(response, transfer_future.meta) download_output_manager = self._get_download_output_manager_cls( transfer_future, osutil @@ -484,6 +501,27 @@ def _submit_ranged_download_request( download_output_manager, io_executor ) ) + + full_object_checksum = None + if ( + transfer_future.meta.checksum_is_provided + and transfer_future.meta.stored_checksum + ): + full_object_checksum = FullObjectChecksum( + transfer_future.meta.checksum_algorithm, + transfer_future.meta.size, + ) + full_object_checksum.set_stored_checksum( + transfer_future.meta.stored_checksum, + ) + validate_checksum_invoker = CountCallbackInvoker( + self._get_validate_checksum_task( + download_output_manager, + io_executor, + full_object_checksum, + ) + ) + for i in range(num_parts): # Calculate the range parameter range_parameter = calculate_range_parameter( @@ -498,6 +536,7 @@ def _submit_ranged_download_request( extra_args['IfMatch'] = transfer_future.meta.etag extra_args.update(call_args.extra_args) finalize_download_invoker.increment() + validate_checksum_invoker.increment() # Submit the ranged downloads self._transfer_coordinator.submit( request_executor, @@ -515,13 +554,38 @@ def _submit_ranged_download_request( 
'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, 'bandwidth_limiter': bandwidth_limiter, + 'full_object_checksum': full_object_checksum, }, - done_callbacks=[finalize_download_invoker.decrement], + done_callbacks=[ + validate_checksum_invoker.decrement, + finalize_download_invoker.decrement, + ], ), tag=get_object_tag, ) + + validate_checksum_invoker.finalize() finalize_download_invoker.finalize() + def _get_validate_checksum_task( + self, + download_manager, + io_executor, + full_object_checksum, + ): + if full_object_checksum is None: + task = CompleteDownloadNOOPTask( + transfer_coordinator=self._transfer_coordinator, + is_final=False, + ) + else: + task = download_manager.get_validate_checksum_task( + full_object_checksum, + ) + return FunctionContainer( + self._transfer_coordinator.submit, io_executor, task + ) + def _get_final_io_task_submission_callback( self, download_manager, io_executor ): @@ -555,6 +619,7 @@ def _main( io_chunksize, start_index=0, bandwidth_limiter=None, + full_object_checksum=None, ): """Downloads an object and places content into io queue @@ -585,8 +650,15 @@ def _main( extra_args.get('Range'), response.get('ContentRange'), ) + streaming_body = response['Body'] + if full_object_checksum: + streaming_body = PartStreamingChecksumBody( + streaming_body, + start_index, + full_object_checksum, + ) streaming_body = StreamReaderProgress( - response['Body'], callbacks + streaming_body, callbacks ) if bandwidth_limiter: streaming_body = ( @@ -860,3 +932,8 @@ def request_writes(self, offset, data): del self._pending_offsets[next_write_offset] self._next_offset += len(next_write) return writes + + +class ValidateChecksumTask(Task): + def _main(self, full_object_checksum): + full_object_checksum.validate() diff --git a/awscli/s3transfer/futures.py b/awscli/s3transfer/futures.py index 6222a42baba8..b6a86c62b53e 100644 --- a/awscli/s3transfer/futures.py +++ b/awscli/s3transfer/futures.py @@ -122,6 +122,8 @@ 
def set_exception(self, exception): class TransferMeta(BaseTransferMeta): """Holds metadata about the TransferFuture""" + _UNPROVIDED = object() + def __init__(self, call_args=None, transfer_id=None): self._call_args = call_args self._transfer_id = transfer_id @@ -129,6 +131,14 @@ def __init__(self, call_args=None, transfer_id=None): self._user_context = {} self._etag = None + # These values are provided via initial HeadObject requests + # when downloading objects. But they're not guaranteed to be + # in the response, in which case `None` values will be provided. + # A sentinel value is set as the default so we can disambiguate + # between "no value provided yet" and "explicit `None` value provided". + self._stored_checksum = self._UNPROVIDED + self._checksum_algorithm = self._UNPROVIDED + @property def call_args(self): """The call args used in the transfer request""" @@ -154,6 +164,28 @@ def etag(self): """The etag of the stored object for validating multipart downloads""" return self._etag + @property + def stored_checksum(self): + """Stored full object checksum value, if any""" + if self._stored_checksum is self._UNPROVIDED: + return None + return self._stored_checksum + + @property + def checksum_algorithm(self): + """Algorithm used to compute stored full object checksum, if any""" + if self._checksum_algorithm is self._UNPROVIDED: + return None + return self._checksum_algorithm + + @property + def checksum_is_provided(self): + """Boolean used to check if checksum properties have been provided""" + return ( + self._stored_checksum is not self._UNPROVIDED + and self._checksum_algorithm is not self._UNPROVIDED + ) + def provide_transfer_size(self, size): """A method to provide the size of a transfer request @@ -172,6 +204,24 @@ def provide_object_etag(self, etag): """ self._etag = etag + def provide_stored_checksum(self, stored_checksum): + """A method to provide the stored checksum of a transfer request + + By providing this value with `checksum_algorithm`, the 
TransferManager + will validate multipart downloads by calculating the + full object checksum and comparing it to the stored checksum. + """ + self._stored_checksum = stored_checksum + + def provide_checksum_algorithm(self, checksum_algorithm): + """A method to provide the checksum algorithm of a transfer request + + By providing this value with `stored_checksum`, the TransferManager + will validate multipart downloads by calculating the + full object checksum and comparing it to the stored checksum. + """ + self._checksum_algorithm = checksum_algorithm + class TransferCoordinator: """A helper class for managing TransferFuture""" diff --git a/tests/__init__.py b/tests/__init__.py index 6e7f26fc8157..5feac09700d2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -93,6 +93,7 @@ RecordingSubscriber, FileSizeProvider, ETagProvider, + ChecksumProvider, RecordingOSUtils, RecordingExecutor, TransferCoordinatorWithInterrupt, diff --git a/tests/functional/s3/test_cp_command.py b/tests/functional/s3/test_cp_command.py index 5494d0a8a7d9..ef09d9112d6b 100644 --- a/tests/functional/s3/test_cp_command.py +++ b/tests/functional/s3/test_cp_command.py @@ -308,6 +308,7 @@ def test_dryrun_download(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -349,6 +350,7 @@ def test_dryrun_copy(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -1305,7 +1307,10 @@ def test_single_download(self): self.assert_operations_called( [ self.head_object_request( - 'mybucket', 'mykey', RequestPayer='requester' + 'mybucket', + 'mykey', + RequestPayer='requester', + **{'ChecksumMode': "ENABLED"}, ), self.get_object_request( 'mybucket', 'mykey', RequestPayer='requester' @@ -1328,7 +1333,10 @@ def test_ranged_download(self): self.assert_operations_called( [ self.head_object_request( - 'mybucket', 'mykey', RequestPayer='requester' + 'mybucket', + 'mykey', + RequestPayer='requester', + **{'ChecksumMode': "ENABLED"}, ), 
self.get_object_request( 'mybucket', @@ -1380,7 +1388,10 @@ def test_single_copy(self): self.assert_operations_called( [ self.head_object_request( - 'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + **{'ChecksumMode': "ENABLED"}, ), self.copy_object_request( 'sourcebucket', @@ -1409,7 +1420,10 @@ def test_multipart_copy(self): self.assert_operations_called( [ self.head_object_request( - 'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + **{'ChecksumMode': "ENABLED"}, ), self.create_mpu_request( 'mybucket', 'mykey', RequestPayer='requester' @@ -1576,7 +1590,11 @@ def test_download(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request(self.accesspoint_arn, 'mykey'), + self.head_object_request( + self.accesspoint_arn, + 'mykey', + **{'ChecksumMode': "ENABLED"}, + ), self.get_object_request(self.accesspoint_arn, 'mykey'), ] ) @@ -1610,7 +1628,11 @@ def test_copy(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request(self.accesspoint_arn, 'mykey'), + self.head_object_request( + self.accesspoint_arn, + 'mykey', + **{'ChecksumMode': "ENABLED"}, + ), self.copy_object_request( self.accesspoint_arn, 'mykey', diff --git a/tests/functional/s3/test_mv_command.py b/tests/functional/s3/test_mv_command.py index 0ef997b50c13..a9d03373dfb6 100644 --- a/tests/functional/s3/test_mv_command.py +++ b/tests/functional/s3/test_mv_command.py @@ -51,6 +51,7 @@ def test_dryrun_move(self): { 'Bucket': 'bucket', 'Key': 'key.txt', + 'ChecksumMode': 'ENABLED', }, ) ] @@ -144,6 +145,7 @@ def test_download_move_with_request_payer(self): 'Bucket': 'mybucket', 'Key': 'mykey', 'RequestPayer': 'requester', + 'ChecksumMode': 'ENABLED', }, ), ( @@ -179,7 +181,10 @@ def test_copy_move_with_request_payer(self): self.assert_operations_called( [ self.head_object_request( - 
'sourcebucket', 'sourcekey', RequestPayer='requester' + 'sourcebucket', + 'sourcekey', + RequestPayer='requester', + **{'ChecksumMode': 'ENABLED'}, ), self.copy_object_request( 'sourcebucket', @@ -216,7 +221,9 @@ def test_with_copy_props(self): self.run_cmd(cmdline, expected_rc=0) self.assert_operations_called( [ - self.head_object_request('sourcebucket', 'sourcekey'), + self.head_object_request( + 'sourcebucket', 'sourcekey', **{'ChecksumMode': 'ENABLED'} + ), self.get_object_tagging_request('sourcebucket', 'sourcekey'), self.create_mpu_request('bucket', 'key', Metadata=metadata), self.upload_part_copy_request( @@ -269,7 +276,9 @@ def test_mv_does_not_delete_source_on_failed_put_tagging(self): self.run_cmd(cmdline, expected_rc=1) self.assert_operations_called( [ - self.head_object_request('sourcebucket', 'sourcekey'), + self.head_object_request( + 'sourcebucket', 'sourcekey', **{'ChecksumMode': 'ENABLED'} + ), self.get_object_tagging_request('sourcebucket', 'sourcekey'), self.create_mpu_request('bucket', 'key', Metadata=metadata), self.upload_part_copy_request( diff --git a/tests/functional/s3transfer/test_download.py b/tests/functional/s3transfer/test_download.py index 66976117a145..f5d19fa95121 100644 --- a/tests/functional/s3transfer/test_download.py +++ b/tests/functional/s3transfer/test_download.py @@ -29,6 +29,7 @@ from tests import ( BaseGeneralInterfaceTest, + ChecksumProvider, ETagProvider, FileSizeProvider, NonSeekableWriter, @@ -60,6 +61,7 @@ def setUp(self): # Create a stream to read from self.content = b'my content' self.stream = BytesIO(self.content) + self.checksum_crc32 = "AUwfuQ==" def tearDown(self): super().tearDown() @@ -106,10 +108,12 @@ def create_expected_progress_callback_info(self): # that the stream is done. 
return [{'bytes_transferred': 10}] - def add_head_object_response(self, expected_params=None): + def add_head_object_response(self, expected_params=None, extras=None): head_response = self.create_stubbed_responses()[0] if expected_params: head_response['expected_params'] = expected_params + if extras: + head_response['service_response'].update(extras) self.stubber.add_response(**head_response) def add_successful_get_object_responses( @@ -308,6 +312,7 @@ def test_can_provide_file_size_and_etag(self): call_kwargs['subscribers'] = [ FileSizeProvider(len(self.content)), ETagProvider(self.etag), + ChecksumProvider({}), ] future = self.manager.download(**call_kwargs) @@ -383,7 +388,9 @@ def test_download(self): 'Key': self.key, 'RequestPayer': 'requester', } - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) self.add_successful_get_object_responses(expected_params) future = self.manager.download( self.bucket, self.key, self.filename, self.extra_args @@ -395,13 +402,13 @@ def test_download(self): self.assertEqual(self.content, f.read()) def test_download_with_checksum_enabled(self): - self.extra_args['ChecksumMode'] = 'ENABLED' expected_params = { 'Bucket': self.bucket, 'Key': self.key, - 'ChecksumMode': 'ENABLED', } - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) self.add_successful_get_object_responses(expected_params) future = self.manager.download( self.bucket, self.key, self.filename, self.extra_args @@ -518,7 +525,9 @@ def test_download(self): } expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 8-9/10'] - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) self.add_successful_get_object_responses( {**expected_params, 'IfMatch': self.etag}, expected_ranges, @@ 
-535,14 +544,14 @@ def test_download(self): self.assertEqual(self.content, f.read()) def test_download_with_checksum_enabled(self): - self.extra_args['ChecksumMode'] = 'ENABLED' expected_params = { 'Bucket': self.bucket, 'Key': self.key, - 'ChecksumMode': 'ENABLED', } expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) self.add_successful_get_object_responses( {**expected_params, 'IfMatch': self.etag}, expected_ranges ) @@ -564,7 +573,9 @@ def test_download_raises_if_content_range_mismatch(self): expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] # Note that the final retrieved range should be `bytes 8-9/10`. stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 7-8/10'] - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) self.add_successful_get_object_responses( {**expected_params, 'IfMatch': self.etag}, expected_ranges, @@ -584,7 +595,9 @@ def test_download_raises_if_etag_validation_fails(self): 'Key': self.key, } expected_ranges = ['bytes=0-3', 'bytes=4-7'] - self.add_head_object_response(expected_params) + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'} + ) # Add successful GetObject responses for the first 2 requests. 
for i, stubbed_response in enumerate( @@ -630,7 +643,7 @@ def test_download_without_etag(self): 'service_response': { 'ContentLength': len(self.content), }, - 'expected_params': expected_params, + 'expected_params': expected_params | {'ChecksumMode': 'ENABLED'}, } self.stubber.add_response(**head_object_response) @@ -647,3 +660,63 @@ def test_download_without_etag(self): # Ensure that the contents are correct with open(self.filename, 'rb') as f: self.assertEqual(self.content, f.read()) + + def test_download_full_object_checksum_validation(self): + self.extra_args['RequestPayer'] = 'requester' + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'RequestPayer': 'requester', + } + expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] + stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 8-9/10'] + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'}, + { + "ChecksumCRC32": self.checksum_crc32, + "ChecksumType": "FULL_OBJECT", + }, + ) + self.add_successful_get_object_responses( + {**expected_params, 'IfMatch': self.etag}, + expected_ranges, + [{"ContentRange": r} for r in stubbed_ranges], + ) + + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + future.result() + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_download_full_object_checksum_validation_mismatch_raises(self): + self.extra_args['RequestPayer'] = 'requester' + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'RequestPayer': 'requester', + } + expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] + stubbed_ranges = ['bytes 0-3/10', 'bytes 4-7/10', 'bytes 8-9/10'] + self.add_head_object_response( + expected_params | {'ChecksumMode': 'ENABLED'}, + {"ChecksumCRC32": "badchecksum", "ChecksumType": "FULL_OBJECT"}, + ) + self.add_successful_get_object_responses( + {**expected_params, 'IfMatch': self.etag}, + 
expected_ranges, + [{"ContentRange": r} for r in stubbed_ranges], + ) + + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + with self.assertRaises(S3ValidationError) as e: + future.result() + self.assertIn('does not match stored checksum', str(e.exception)) + + # Ensure no data is written to disk. + self.assertFalse(os.path.exists(self.filename)) diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 765f66b2c8e9..1b1af26d66af 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -43,6 +43,7 @@ DeleteSourceFileSubscriber, DeleteSourceObjectSubscriber, DirectoryCreatorSubscriber, + ProvideChecksumSubscriber, ProvideETagSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, @@ -524,6 +525,7 @@ def test_submit(self): QueuedResultSubscriber, DirectoryCreatorSubscriber, ProvideLastModifiedTimeSubscriber, + ProvideChecksumSubscriber, ProgressResultSubscriber, DoneResultSubscriber, ] @@ -670,6 +672,7 @@ def test_submit_move_adds_delete_source_subscriber(self): DirectoryCreatorSubscriber, ProvideLastModifiedTimeSubscriber, DeleteSourceObjectSubscriber, + ProvideChecksumSubscriber, ProgressResultSubscriber, DoneResultSubscriber, ] diff --git a/tests/unit/customizations/s3/test_subscribers.py b/tests/unit/customizations/s3/test_subscribers.py index 906b4970e7d3..d8b37a71287a 100644 --- a/tests/unit/customizations/s3/test_subscribers.py +++ b/tests/unit/customizations/s3/test_subscribers.py @@ -35,6 +35,7 @@ DeleteSourceObjectSubscriber, DirectoryCreatorSubscriber, OnDoneFilteredSubscriber, + ProvideChecksumSubscriber, ProvideETagSubscriber, ProvideLastModifiedTimeSubscriber, ProvideSizeSubscriber, @@ -101,6 +102,59 @@ def test_does_not_try_to_set_etag_on_crt_transfer_future(self, caplog): assert "Not providing object ETag." 
in caplog.text +class TestProvideChecksumSubscriber: + @pytest.mark.parametrize( + "response,expected_checksum,expected_algorithm", + [ + ( + {"ChecksumType": "FULL_OBJECT", "ChecksumCRC32": "foobar"}, + "foobar", + "ChecksumCRC32", + ), + ( + {"ChecksumType": "COMPOSITE", "ChecksumCRC32": "foobar"}, + None, + None, + ), + ( + {"ChecksumType": "FULL_OBJECT", "ChecksumSHA256": "foobar"}, + None, + None, + ), + ], + ) + def test_checksum_set( + self, response, expected_checksum, expected_algorithm + ): + transfer_meta = TransferMeta() + transfer_future = mock.Mock(spec=TransferFuture) + transfer_future.meta = transfer_meta + assert transfer_meta.checksum_is_provided is False + + subscriber = ProvideChecksumSubscriber(response) + subscriber.on_queued(transfer_future) + assert transfer_meta.stored_checksum == expected_checksum + assert transfer_meta.checksum_algorithm == expected_algorithm + assert transfer_meta.checksum_is_provided is True + + def test_does_not_try_to_set_checksum_on_crt_transfer_future(self, caplog): + caplog.set_level(logging.DEBUG) + crt_transfer_future = mock.Mock(spec=CRTTransferFuture) + crt_transfer_future.meta = CRTTransferMeta() + + subscriber = ProvideChecksumSubscriber( + { + "ChecksumType": "FULL_OBJECT", + "ChecksumCRC32": "foobar", + } + ) + subscriber.on_queued(crt_transfer_future) + assert not hasattr(crt_transfer_future.meta, 'stored_checksum') + assert not hasattr(crt_transfer_future.meta, 'checksum_algorithm') + + assert "Not providing stored checksum." 
in caplog.text + + class OnDoneFilteredRecordingSubscriber(OnDoneFilteredSubscriber): def __init__(self): self.on_success_calls = [] diff --git a/tests/unit/s3transfer/test_checksums.py b/tests/unit/s3transfer/test_checksums.py new file mode 100644 index 000000000000..b4e057112cd1 --- /dev/null +++ b/tests/unit/s3transfer/test_checksums.py @@ -0,0 +1,155 @@ +from io import BytesIO + +import pytest +from botocore.httpchecksum import CrtCrc32Checksum +from s3transfer.checksums import ( + FullObjectChecksum, + PartStreamingChecksumBody, + provide_checksum_to_meta, +) +from s3transfer.exceptions import S3ValidationError + +from tests import mock + + +def read_from_stream(stream): + data = b"" + val = stream.read() + while val: + data += val + val = stream.read() + return data + + +@pytest.fixture +def stream(): + return BytesIO(b'hello world') + + +@pytest.fixture +def mock_full_object_checksum(): + mock_full_object_checksum = mock.MagicMock(FullObjectChecksum) + mock_full_object_checksum.checksum_algorithm = 'ChecksumCRC32' + return mock_full_object_checksum + + +@pytest.fixture +def part_streaming_checksum_body(stream, mock_full_object_checksum): + return PartStreamingChecksumBody(stream, 0, mock_full_object_checksum) + + +class TestPartStreamingChecksumBody: + def test_basic_read(self, part_streaming_checksum_body): + read_data = read_from_stream(part_streaming_checksum_body) + assert read_data == b'hello world' + + def test_sets_part_checksum( + self, stream, mock_full_object_checksum, part_streaming_checksum_body + ): + read_from_stream(part_streaming_checksum_body) + mock_full_object_checksum.set_part_checksum.assert_called_with( + 0, 222957957 + ) + + def test_reuses_checksum(self, stream, mock_full_object_checksum): + mock_checksum = mock.MagicMock(CrtCrc32Checksum) + mock_checksum.int_crc = 111111111 + stream.checksum = mock_checksum + + part_streaming_checksum_body = PartStreamingChecksumBody( + stream, 0, mock_full_object_checksum + ) + 
read_from_stream(part_streaming_checksum_body) + mock_full_object_checksum.set_part_checksum.assert_called_with( + 0, 111111111 + ) + + +class TestFullObjectChecksum: + def generate_part_checksum_bodies(self, n, full_object_checksum): + parts = [] + for i in range(n): + parts.append( + PartStreamingChecksumBody( + BytesIO(f"Part{i}".encode()), + i * 5, + full_object_checksum, + ) + ) + return parts + + @pytest.mark.parametrize( + "n_parts,expected", [(2, "Fd9B+g=="), (10, "eywYgg==")] + ) + def test_calculated_checksum(self, n_parts, expected): + full_object_checksum = FullObjectChecksum("ChecksumCRC32", 5 * n_parts) + parts = self.generate_part_checksum_bodies( + n_parts, full_object_checksum + ) + for part in parts: + read_from_stream(part) + assert full_object_checksum.calculated_checksum == expected + + def test_checksum_mismatch_raises(self): + n_parts = 10 + full_object_checksum = FullObjectChecksum("ChecksumCRC32", 5 * n_parts) + parts = self.generate_part_checksum_bodies( + n_parts, full_object_checksum + ) + for part in parts: + read_from_stream(part) + full_object_checksum.set_stored_checksum('foobar') + with pytest.raises(S3ValidationError) as exc_info: + full_object_checksum.validate() + assert "does not match stored checksum" in str(exc_info.value) + + @pytest.mark.parametrize("content_length", [0, 49, 51]) + def test_wrong_content_length_raises(self, content_length): + n_parts = 10 + # Note that total content length should be `50`. 
+ full_object_checksum = FullObjectChecksum( + "ChecksumCRC32", content_length + ) + parts = self.generate_part_checksum_bodies( + n_parts, full_object_checksum + ) + for part in parts: + read_from_stream(part) + full_object_checksum.set_stored_checksum("eywYgg==") + with pytest.raises(S3ValidationError) as exc_info: + full_object_checksum.validate() + assert "does not match stored checksum" in str(exc_info.value) + + +class TestProvideChecksumToMeta: + def test_provides_checksum_to_meta(self): + response = { + "ChecksumType": "FULL_OBJECT", + "ChecksumCRC32": "foobar", + } + mock_transfer_meta = mock.Mock() + provide_checksum_to_meta(response, mock_transfer_meta) + mock_transfer_meta.provide_checksum_algorithm.assert_called_with( + "ChecksumCRC32" + ) + mock_transfer_meta.provide_stored_checksum.assert_called_with("foobar") + + def test_provides_none_if_composite(self): + response = { + "ChecksumType": "COMPOSITE", + "ChecksumCRC32": "foobar", + } + mock_transfer_meta = mock.Mock() + provide_checksum_to_meta(response, mock_transfer_meta) + mock_transfer_meta.provide_checksum_algorithm.assert_called_with(None) + mock_transfer_meta.provide_stored_checksum.assert_called_with(None) + + def test_provides_none_if_not_crc(self): + response = { + "ChecksumType": "FULL_OBJECT", + "ChecksumFooBar": "foobar", + } + mock_transfer_meta = mock.Mock() + provide_checksum_to_meta(response, mock_transfer_meta) + mock_transfer_meta.provide_checksum_algorithm.assert_called_with(None) + mock_transfer_meta.provide_stored_checksum.assert_called_with(None) diff --git a/tests/unit/s3transfer/test_download.py b/tests/unit/s3transfer/test_download.py index 57d6418d54eb..dfb49cd65ca5 100644 --- a/tests/unit/s3transfer/test_download.py +++ b/tests/unit/s3transfer/test_download.py @@ -33,8 +33,9 @@ IORenameFileTask, IOStreamingWriteTask, IOWriteTask, + ValidateChecksumTask, ) -from s3transfer.exceptions import RetriesExceededError +from s3transfer.exceptions import RetriesExceededError, 
S3ValidationError from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor from s3transfer.utils import CallArgs, OSUtils @@ -48,6 +49,7 @@ mock, unittest, ) +from tests.unit.s3transfer.test_checksums import mock_full_object_checksum class DownloadException(Exception): @@ -188,6 +190,14 @@ def test_get_file_io_write_task(self): io_write_task() self.assertEqual(fileobj.writes, [(3, 'foo')]) + def test_get_validate_checksum_task(self): + self.assertIsInstance( + self.download_output_manager.get_validate_checksum_task( + mock.Mock() + ), + ValidateChecksumTask, + ) + class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest): def setUp(self): @@ -397,6 +407,7 @@ def setUp(self): self.bucket = 'mybucket' self.key = 'mykey' self.etag = 'myetag' + self.checksum_crc32 = 'AUwfuQ==' self.extra_args = {'IfMatch': self.etag} self.subscribers = [] @@ -568,6 +579,44 @@ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self): # to that task submission. self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG) + def test_full_object_checksum_validation(self): + self.wrap_executor_in_recorder() + self.configure_for_ranged_get() + self.stubber.add_response( + 'head_object', + { + 'ContentLength': len(self.content), + 'ETag': self.etag, + 'ChecksumCRC32': self.checksum_crc32, + 'ChecksumType': 'FULL_OBJECT', + }, + ) + self.add_get_responses() + + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + def test_full_object_checksum_validation_raises(self): + self.wrap_executor_in_recorder() + self.configure_for_ranged_get() + self.stubber.add_response( + 'head_object', + { + 'ContentLength': len(self.content), + 'ETag': self.etag, + 'ChecksumCRC32': 'badchecksum', + 'ChecksumType': 'FULL_OBJECT', + }, + ) + self.add_get_responses() + + self.submission_task = self.get_download_submission_task() + with self.assertRaisesRegex( + S3ValidationError, r'does not match 
stored checksum' + ): + self.submission_task() + self.transfer_future.result() + class TestGetObjectTask(BaseTaskTest): def setUp(self): @@ -684,6 +733,25 @@ def test_uses_bandwidth_limiter(self): [mock.call(mock.ANY, self.transfer_coordinator)], ) + def test_uses_full_object_checksum(self): + full_object_checksum = mock.Mock() + full_object_checksum.checksum_algorithm = 'ChecksumCRC32' + + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task( + full_object_checksum=full_object_checksum + ) + task() + + self.stubber.assert_no_pending_responses() + full_object_checksum.set_part_checksum.assert_called_once_with( + 0, 21766073 + ) + def test_retries_succeeds(self): self.stubber.add_response( 'get_object', @@ -884,6 +952,17 @@ def test_main(self): self.assertTrue(f.closed) +class TestValidateChecksumTask(BaseIOTaskTest): + def test_main(self): + mock_full_object_checksum = mock.Mock() + task = self.get_task( + ValidateChecksumTask, + main_kwargs={'full_object_checksum': mock_full_object_checksum}, + ) + task() + mock_full_object_checksum.validate.assert_called_once() + + class TestDownloadChunkIterator(unittest.TestCase): def test_iter(self): content = b'my content' diff --git a/tests/unit/s3transfer/test_futures.py b/tests/unit/s3transfer/test_futures.py index a4c015b10cae..99f82731d496 100644 --- a/tests/unit/s3transfer/test_futures.py +++ b/tests/unit/s3transfer/test_futures.py @@ -169,6 +169,27 @@ def test_user_context(self): self.transfer_meta.user_context['foo'] = 'bar' self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'}) + def test_checksum(self): + self.transfer_meta.provide_stored_checksum('foo') + self.transfer_meta.provide_checksum_algorithm('bar') + self.assertEqual(self.transfer_meta.stored_checksum, 'foo') + self.assertEqual(self.transfer_meta.checksum_algorithm, 'bar') + 
self.assertTrue(self.transfer_meta.checksum_is_provided) + + def test_checksum_not_provided(self): + self.assertFalse(self.transfer_meta.checksum_is_provided) + transfer_meta_with_checksum = TransferMeta() + transfer_meta_with_checksum.provide_stored_checksum('foo') + self.assertFalse(transfer_meta_with_checksum.checksum_is_provided) + transfer_meta_with_algorithm = TransferMeta() + transfer_meta_with_algorithm.provide_checksum_algorithm('bar') + self.assertFalse(transfer_meta_with_algorithm.checksum_is_provided) + + def test_checksum_provided_none(self): + self.transfer_meta.provide_stored_checksum(None) + self.transfer_meta.provide_checksum_algorithm(None) + self.assertTrue(self.transfer_meta.checksum_is_provided) + class TestTransferCoordinator(unittest.TestCase): def setUp(self): diff --git a/tests/utils/s3transfer/__init__.py b/tests/utils/s3transfer/__init__.py index 4c1cf87b8c91..544ac5f62c67 100644 --- a/tests/utils/s3transfer/__init__.py +++ b/tests/utils/s3transfer/__init__.py @@ -23,6 +23,7 @@ import botocore.session from botocore.stub import Stubber +from s3transfer.checksums import provide_checksum_to_meta from s3transfer.futures import ( IN_MEMORY_DOWNLOAD_TAG, IN_MEMORY_UPLOAD_TAG, @@ -162,6 +163,14 @@ def on_queued(self, future, **kwargs): future.meta.provide_object_etag(self.etag) +class ChecksumProvider: + def __init__(self, response_data): + self.response_data = response_data + + def on_queued(self, future, **kwargs): + provide_checksum_to_meta(self.response_data, future.meta) + + class FileCreator: def __init__(self): self.rootdir = tempfile.mkdtemp()