diff --git a/.github/workflows/quality-checks.yml b/.github/workflows/quality-checks.yml index d4a76066b..bcabd14d6 100644 --- a/.github/workflows/quality-checks.yml +++ b/.github/workflows/quality-checks.yml @@ -175,7 +175,7 @@ jobs: working-directory: lambdas/ack_backend id: acklambda env: - PYTHONPATH: ${{ env.LAMBDA_PATH }}/ack_backend/src:${{ github.workspace }}/ack_backend/tests + PYTHONPATH: ${{ env.LAMBDA_PATH }}/ack_backend/src:tests:${{ env.SHARED_PATH }}/src continue-on-error: true run: | poetry install diff --git a/lambdas/ack_backend/poetry.lock b/lambdas/ack_backend/poetry.lock index 58f4b06fb..adfa6f0ce 100644 --- a/lambdas/ack_backend/poetry.lock +++ b/lambdas/ack_backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "boto3" @@ -532,47 +532,48 @@ files = [ [[package]] name = "moto" -version = "4.2.14" -description = "" +version = "5.1.14" +description = "A library that allows you to easily mock out tests based on AWS infrastructure" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "moto-4.2.14-py2.py3-none-any.whl", hash = "sha256:6d242dbbabe925bb385ddb6958449e5c827670b13b8e153ed63f91dbdb50372c"}, - {file = "moto-4.2.14.tar.gz", hash = "sha256:8f9263ca70b646f091edcc93e97cda864a542e6d16ed04066b1370ed217bd190"}, + {file = "moto-5.1.14-py3-none-any.whl", hash = "sha256:b9767848953beaf6650f1fd91615a3bcef84d93bd00603fa64dae38c656548e8"}, + {file = "moto-5.1.14.tar.gz", hash = "sha256:450690abb0b152fea7f93e497ac2172f15d8a838b15f22b514db801a6b857ae4"}, ] [package.dependencies] boto3 = ">=1.9.201" -botocore = ">=1.12.201" -cryptography = ">=3.3.1" +botocore = ">=1.20.88,<1.35.45 || >1.35.45,<1.35.46 || >1.35.46" +cryptography = ">=35.0.0" Jinja2 = ">=2.10.1" python-dateutil = ">=2.1,<3.0.0" requests = ">=2.5" -responses = ">=0.13.0" +responses = ">=0.15.0,<0.25.5 || >0.25.5" werkzeug = ">=0.5,<2.2.0 || >2.2.0,<2.2.1 || >2.2.1" xmltodict = "*" [package.extras] -all = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "ecdsa (!=0.15)", "graphql-core", "jsondiff (>=1.1.2)", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.0)", "pyparsing (>=3.0.7)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "setuptools", "sshpubkeys (>=3.1.0)"] -apigateway = ["PyYAML (>=5.1)", "ecdsa (!=0.15)", "openapi-spec-validator (>=0.5.0)", "python-jose[cryptography] (>=3.1.0,<4.0.0)"] -apigatewayv2 = ["PyYAML (>=5.1)"] +all = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "jsonpath_ng", "jsonschema", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.6.1)", "pyparsing (>=3.0.7)", "setuptools"] +apigateway = ["PyYAML (>=5.1)", "joserfc (>=0.9.0)", "openapi-spec-validator (>=0.5.0)"] +apigatewayv2 = ["PyYAML (>=5.1)", "openapi-spec-validator (>=0.5.0)"] appsync = ["graphql-core"] awslambda = ["docker (>=3.0.0)"] batch = ["docker (>=3.0.0)"] -cloudformation = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "ecdsa (!=0.15)", "graphql-core", "jsondiff (>=1.1.2)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.0)", "pyparsing (>=3.0.7)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "setuptools", "sshpubkeys 
(>=3.1.0)"] -cognitoidp = ["ecdsa (!=0.15)", "python-jose[cryptography] (>=3.1.0,<4.0.0)"] -dynamodb = ["docker (>=3.0.0)", "py-partiql-parser (==0.5.0)"] -dynamodbstreams = ["docker (>=3.0.0)", "py-partiql-parser (==0.5.0)"] -ec2 = ["sshpubkeys (>=3.1.0)"] +cloudformation = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.6.1)", "pyparsing (>=3.0.7)", "setuptools"] +cognitoidp = ["joserfc (>=0.9.0)"] +dynamodb = ["docker (>=3.0.0)", "py-partiql-parser (==0.6.1)"] +dynamodbstreams = ["docker (>=3.0.0)", "py-partiql-parser (==0.6.1)"] +events = ["jsonpath_ng"] glue = ["pyparsing (>=3.0.7)"] -iotdata = ["jsondiff (>=1.1.2)"] -proxy = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=2.5.1)", "ecdsa (!=0.15)", "graphql-core", "jsondiff (>=1.1.2)", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.0)", "pyparsing (>=3.0.7)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "setuptools", "sshpubkeys (>=3.1.0)"] -resourcegroupstaggingapi = ["PyYAML (>=5.1)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "ecdsa (!=0.15)", "graphql-core", "jsondiff (>=1.1.2)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.0)", "pyparsing (>=3.0.7)", "python-jose[cryptography] (>=3.1.0,<4.0.0)"] -s3 = ["PyYAML (>=5.1)", "py-partiql-parser (==0.5.0)"] -s3crc32c = ["PyYAML (>=5.1)", "crc32c", "py-partiql-parser (==0.5.0)"] -server = ["PyYAML (>=5.1)", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "ecdsa (!=0.15)", "flask (!=2.2.0,!=2.2.1)", "flask-cors", "graphql-core", "jsondiff (>=1.1.2)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.5.0)", "pyparsing (>=3.0.7)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "setuptools", "sshpubkeys (>=3.1.0)"] +proxy = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=2.5.1)", "graphql-core", "joserfc (>=0.9.0)", "jsonpath_ng", "multipart", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.6.1)", "pyparsing (>=3.0.7)", "setuptools"] +quicksight = ["jsonschema"] +resourcegroupstaggingapi = ["PyYAML (>=5.1)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "graphql-core", "joserfc (>=0.9.0)", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.6.1)", "pyparsing (>=3.0.7)"] +s3 = ["PyYAML (>=5.1)", "py-partiql-parser (==0.6.1)"] +s3crc32c = ["PyYAML (>=5.1)", "crc32c", "py-partiql-parser (==0.6.1)"] +server = ["PyYAML (>=5.1)", "antlr4-python3-runtime", "aws-xray-sdk (>=0.93,!=0.96)", "cfn-lint (>=0.40.0)", "docker (>=3.0.0)", "flask (!=2.2.0,!=2.2.1)", "flask-cors", "graphql-core", "joserfc (>=0.9.0)", "jsonpath_ng", "openapi-spec-validator (>=0.5.0)", "py-partiql-parser (==0.6.1)", "pyparsing (>=3.0.7)", "setuptools"] ssm = ["PyYAML (>=5.1)"] +stepfunctions = ["antlr4-python3-runtime", "jsonpath_ng"] xray = ["aws-xray-sdk (>=0.93,!=0.96)", "setuptools"] [[package]] @@ -836,4 +837,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = "~3.11" -content-hash = "31f699335ebf55c2b60a1644c8732ae1f6ac396e82fb64b264abb33594e35ab2" +content-hash = "8a50b352a14a3ba5d16c750f254ce9d6b6fada65cc912e1665d4a37192c5c24e" diff --git a/lambdas/ack_backend/pyproject.toml b/lambdas/ack_backend/pyproject.toml index b65647882..d08181aaa 100644 --- a/lambdas/ack_backend/pyproject.toml +++ b/lambdas/ack_backend/pyproject.toml @@ -14,7 +14,7 @@ python = "~3.11" boto3 = 
"~1.40.45" mypy-boto3-dynamodb = "^1.40.44" freezegun = "^1.5.2" -moto = "^4" +moto = "^5.1.14" coverage = "^7.10.7" diff --git a/lambdas/ack_backend/src/ack_processor.py b/lambdas/ack_backend/src/ack_processor.py index de84fee57..0c44ef6e1 100644 --- a/lambdas/ack_backend/src/ack_processor.py +++ b/lambdas/ack_backend/src/ack_processor.py @@ -2,12 +2,13 @@ import json from logging_decorators import ack_lambda_handler_logging_decorator -from update_ack_file import update_ack_file +from update_ack_file import update_ack_file, complete_batch_file_process +from utils_for_ack_lambda import is_ack_processing_complete from convert_message_to_ack_row import convert_message_to_ack_row @ack_lambda_handler_logging_decorator -def lambda_handler(event, context): +def lambda_handler(event, _): """ Ack lambda handler. For each record: each message in the array of messages is converted to an ack row, @@ -22,6 +23,7 @@ def lambda_handler(event, context): message_id = None ack_data_rows = [] + total_ack_rows_processed = 0 for i, record in enumerate(event["Records"]): @@ -31,10 +33,8 @@ def lambda_handler(event, context): raise ValueError("Could not load incoming message body") from body_json_error if i == 0: - # IMPORTANT NOTE: An assumption is made here that the file_key and created_at_formatted_string are the same - # for all messages in the event. The use of FIFO SQS queues ensures that this is the case, provided that - # there is only one file processing at a time for each supplier queue (combination of supplier and vaccine - # type). + # The SQS FIFO MessageGroupId that this lambda consumes from is based on the source filename + created at + # datetime. Therefore, can safely retrieve file metadata from the first record in the list file_key = incoming_message_body[0].get("file_key") message_id = (incoming_message_body[0].get("row_id", "")).split("^")[0] vaccine_type = incoming_message_body[0].get("vaccine_type") @@ -44,14 +44,16 @@ def lambda_handler(event, context): for message in incoming_message_body: ack_data_rows.append(convert_message_to_ack_row(message, created_at_formatted_string)) - update_ack_file( - file_key, - message_id, - supplier, - vaccine_type, - created_at_formatted_string, - ack_data_rows, - ) + update_ack_file(file_key, created_at_formatted_string, ack_data_rows) + + # Get the row count of the final processed record + # Format of the row id is {batch_message_id}^{row_number} + total_ack_rows_processed = int(incoming_message_body[-1].get("row_id", "").split("^")[1]) + + if is_ack_processing_complete(message_id, total_ack_rows_processed): + complete_batch_file_process( + message_id, supplier, vaccine_type, created_at_formatted_string, file_key, total_ack_rows_processed + ) return { "statusCode": 200, diff --git a/lambdas/ack_backend/src/audit_table.py b/lambdas/ack_backend/src/audit_table.py index 48a723d27..b21ab08d7 100644 --- a/lambdas/ack_backend/src/audit_table.py +++ b/lambdas/ack_backend/src/audit_table.py @@ -1,5 +1,6 @@ """Add the filename to the audit table and check for duplicates.""" +from typing import Optional from common.clients import dynamodb_client, logger from common.models.errors import UnhandledAuditTableError from constants import AUDIT_TABLE_NAME, FileStatus, AuditTableKeys @@ -28,3 +29,17 @@ def change_audit_table_status_to_processed(file_key: str, message_id: str) -> No except Exception as error: # pylint: disable = broad-exception-caught logger.error(error) raise UnhandledAuditTableError(error) from error + + +def 
get_record_count_by_message_id(event_message_id: str) -> Optional[int]: + """Retrieves full audit entry by unique event message ID""" + audit_record = dynamodb_client.get_item( + TableName=AUDIT_TABLE_NAME, Key={AuditTableKeys.MESSAGE_ID: {"S": event_message_id}} + ) + + record_count = audit_record.get("Item", {}).get(AuditTableKeys.RECORD_COUNT, {}).get("N") + + if not record_count: + return None + + return int(record_count) diff --git a/lambdas/ack_backend/src/constants.py b/lambdas/ack_backend/src/constants.py index f6229dbd0..7e8c2ceec 100644 --- a/lambdas/ack_backend/src/constants.py +++ b/lambdas/ack_backend/src/constants.py @@ -4,6 +4,11 @@ AUDIT_TABLE_NAME = os.getenv("AUDIT_TABLE_NAME") +COMPLETED_ACK_DIR = "forwardedFile" +TEMP_ACK_DIR = "TempAck" +BATCH_FILE_PROCESSING_DIR = "processing" +BATCH_FILE_ARCHIVE_DIR = "archive" + def get_source_bucket_name() -> str: """Get the SOURCE_BUCKET_NAME environment from environment variables.""" @@ -30,6 +35,7 @@ class AuditTableKeys: FILENAME = "filename" MESSAGE_ID = "message_id" QUEUE_NAME = "queue_name" + RECORD_COUNT = "record_count" STATUS = "status" TIMESTAMP = "timestamp" diff --git a/lambdas/ack_backend/src/logging_decorators.py b/lambdas/ack_backend/src/logging_decorators.py index 9327c9e9e..e78cb7671 100644 --- a/lambdas/ack_backend/src/logging_decorators.py +++ b/lambdas/ack_backend/src/logging_decorators.py @@ -68,7 +68,7 @@ def wrapper(message, created_at_formatted_string): return wrapper -def upload_ack_file_logging_decorator(func): +def complete_batch_file_process_logging_decorator(func): """This decorator logs when record processing is complete.""" @wraps(func) diff --git a/lambdas/ack_backend/src/update_ack_file.py b/lambdas/ack_backend/src/update_ack_file.py index 3ae60dd86..d851cb416 100644 --- a/lambdas/ack_backend/src/update_ack_file.py +++ b/lambdas/ack_backend/src/update_ack_file.py @@ -2,12 +2,18 @@ from botocore.exceptions import ClientError from io import StringIO, BytesIO -from typing import Optional from audit_table import change_audit_table_status_to_processed from common.clients import get_s3_client, logger -from constants import ACK_HEADERS, get_source_bucket_name, get_ack_bucket_name -from logging_decorators import upload_ack_file_logging_decorator -from utils_for_ack_lambda import get_row_count +from constants import ( + ACK_HEADERS, + get_source_bucket_name, + get_ack_bucket_name, + COMPLETED_ACK_DIR, + TEMP_ACK_DIR, + BATCH_FILE_PROCESSING_DIR, + BATCH_FILE_ARCHIVE_DIR, +) +from logging_decorators import complete_batch_file_process_logging_decorator def create_ack_data( @@ -45,6 +51,35 @@ def create_ack_data( } +@complete_batch_file_process_logging_decorator +def complete_batch_file_process( + message_id: str, + supplier: str, + vaccine_type: str, + created_at_formatted_string: str, + file_key: str, + total_ack_rows_processed: int, +) -> dict: + """Mark the batch file as processed. 
This involves moving the ack and original file to destinations and updating + the audit table status""" + ack_filename = f"{file_key.replace('.csv', f'_BusAck_{created_at_formatted_string}.csv')}" + + move_file(get_ack_bucket_name(), f"{TEMP_ACK_DIR}/{ack_filename}", f"{COMPLETED_ACK_DIR}/{ack_filename}") + move_file( + get_source_bucket_name(), f"{BATCH_FILE_PROCESSING_DIR}/{file_key}", f"{BATCH_FILE_ARCHIVE_DIR}/{file_key}" + ) + + change_audit_table_status_to_processed(file_key, message_id) + + return { + "message_id": message_id, + "file_key": file_key, + "supplier": supplier, + "vaccine_type": vaccine_type, + "row_count": total_ack_rows_processed, + } + + def obtain_current_ack_content(temp_ack_file_key: str) -> StringIO: """Returns the current ack file content if the file exists, or else initialises the content with the ack headers.""" try: @@ -65,76 +100,27 @@ def obtain_current_ack_content(temp_ack_file_key: str) -> StringIO: return accumulated_csv_content -@upload_ack_file_logging_decorator -def upload_ack_file( - temp_ack_file_key: str, - message_id: str, - supplier: str, - vaccine_type: str, - accumulated_csv_content: StringIO, - ack_data_rows: list, - archive_ack_file_key: str, +def update_ack_file( file_key: str, -) -> Optional[dict]: - """Adds the data row to the uploaded ack file""" + created_at_formatted_string: str, + ack_data_rows: list, +) -> None: + """Updates the ack file with the new data row based on the given arguments""" + ack_filename = f"{file_key.replace('.csv', f'_BusAck_{created_at_formatted_string}.csv')}" + temp_ack_file_key = f"{TEMP_ACK_DIR}/{ack_filename}" + archive_ack_file_key = f"{COMPLETED_ACK_DIR}/{ack_filename}" + accumulated_csv_content = obtain_current_ack_content(temp_ack_file_key) + for row in ack_data_rows: data_row_str = [str(item) for item in row.values()] cleaned_row = "|".join(data_row_str).replace(" |", "|").replace("| ", "|").strip() accumulated_csv_content.write(cleaned_row + "\n") - csv_file_like_object = BytesIO(accumulated_csv_content.getvalue().encode("utf-8")) + csv_file_like_object = BytesIO(accumulated_csv_content.getvalue().encode("utf-8")) ack_bucket_name = get_ack_bucket_name() - source_bucket_name = get_source_bucket_name() get_s3_client().upload_fileobj(csv_file_like_object, ack_bucket_name, temp_ack_file_key) - - row_count_source = get_row_count(source_bucket_name, f"processing/{file_key}") - row_count_destination = get_row_count(ack_bucket_name, temp_ack_file_key) - # TODO: Should we check for > and if so what handling is required - if row_count_destination == row_count_source: - move_file(ack_bucket_name, temp_ack_file_key, archive_ack_file_key) - move_file(source_bucket_name, f"processing/{file_key}", f"archive/{file_key}") - - # Update the audit table - change_audit_table_status_to_processed(file_key, message_id) - - # Ingestion of this file is complete - result = { - "message_id": message_id, - "file_key": file_key, - "supplier": supplier, - "vaccine_type": vaccine_type, - "row_count": row_count_source - 1, - } - else: - result = None logger.info("Ack file updated to %s: %s", ack_bucket_name, archive_ack_file_key) - return result - - -def update_ack_file( - file_key: str, - message_id: str, - supplier: str, - vaccine_type: str, - created_at_formatted_string: str, - ack_data_rows: list, -) -> None: - """Updates the ack file with the new data row based on the given arguments""" - ack_filename = f"{file_key.replace('.csv', f'_BusAck_{created_at_formatted_string}.csv')}" - temp_ack_file_key = f"TempAck/{ack_filename}" - 
archive_ack_file_key = f"forwardedFile/{ack_filename}" - accumulated_csv_content = obtain_current_ack_content(temp_ack_file_key) - upload_ack_file( - temp_ack_file_key, - message_id, - supplier, - vaccine_type, - accumulated_csv_content, - ack_data_rows, - archive_ack_file_key, - file_key, - ) def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None: diff --git a/lambdas/ack_backend/src/utils_for_ack_lambda.py b/lambdas/ack_backend/src/utils_for_ack_lambda.py index 29b7d4c66..8c76d7fff 100644 --- a/lambdas/ack_backend/src/utils_for_ack_lambda.py +++ b/lambdas/ack_backend/src/utils_for_ack_lambda.py @@ -1,12 +1,21 @@ """Utils for ack lambda""" -from common.clients import get_s3_client +from audit_table import get_record_count_by_message_id +_BATCH_EVENT_ID_TO_RECORD_COUNT_MAP: dict[str, int] = {} -def get_row_count(bucket_name: str, file_key: str) -> int: - """ - Looks in the given bucket and returns the count of the number of lines in the given file. - NOTE: Blank lines are not included in the count. - """ - response = get_s3_client().get_object(Bucket=bucket_name, Key=file_key) - return sum(1 for line in response["Body"].iter_lines() if line.strip()) + +def is_ack_processing_complete(batch_event_message_id: str, processed_ack_count: int) -> bool: + """Checks if we have received all the acknowledgement rows for the original source file. Also caches the value of + the source file record count to reduce traffic to DynamoDB""" + if batch_event_message_id in _BATCH_EVENT_ID_TO_RECORD_COUNT_MAP: + return _BATCH_EVENT_ID_TO_RECORD_COUNT_MAP[batch_event_message_id] == processed_ack_count + + record_count = get_record_count_by_message_id(batch_event_message_id) + + if not record_count: + # Record count is not set on the audit item until all rows have been preprocessed and sent to Kinesis + return False + + _BATCH_EVENT_ID_TO_RECORD_COUNT_MAP[batch_event_message_id] = record_count + return record_count == processed_ack_count diff --git a/lambdas/ack_backend/tests/test_ack_processor.py b/lambdas/ack_backend/tests/test_ack_processor.py index 82e6a3262..ebaffe885 100644 --- a/lambdas/ack_backend/tests/test_ack_processor.py +++ b/lambdas/ack_backend/tests/test_ack_processor.py @@ -3,30 +3,33 @@ import unittest import os import json -from unittest.mock import patch +from unittest.mock import patch, Mock from io import StringIO from boto3 import client as boto3_client -from moto import mock_s3, mock_firehose +from moto import mock_aws -from tests.utils.mock_environment_variables import ( +from utils.mock_environment_variables import ( + AUDIT_TABLE_NAME, MOCK_ENVIRONMENT_DICT, BucketNames, REGION_NAME, ) -from tests.utils.generic_setup_and_teardown_for_ack_backend import ( +from utils.generic_setup_and_teardown_for_ack_backend import ( GenericSetUp, GenericTearDown, ) -from tests.utils.utils_for_ack_backend_tests import ( - setup_existing_ack_file, +from utils.utils_for_ack_backend_tests import ( + add_audit_entry_to_table, validate_ack_file_content, + generate_sample_existing_ack_content, ) -from tests.utils.values_for_ack_backend_tests import ( +from utils.values_for_ack_backend_tests import ( DiagnosticsDictionaries, MOCK_MESSAGE_DETAILS, ValidValues, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS, ) +from utils_for_ack_lambda import _BATCH_EVENT_ID_TO_RECORD_COUNT_MAP with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT): from ack_processor import lambda_handler @@ -39,18 +42,17 @@ @patch.dict(os.environ, MOCK_ENVIRONMENT_DICT) -@mock_s3 -@mock_firehose 
+@patch("audit_table.AUDIT_TABLE_NAME", AUDIT_TABLE_NAME) +@mock_aws class TestAckProcessor(unittest.TestCase): """Tests for the ack processor lambda handler.""" def setUp(self) -> None: self.s3_client = boto3_client("s3", region_name=REGION_NAME) self.firehose_client = boto3_client("firehose", region_name=REGION_NAME) - GenericSetUp(self.s3_client, self.firehose_client) + self.dynamodb_client = boto3_client("dynamodb", region_name=REGION_NAME) + GenericSetUp(self.s3_client, self.firehose_client, self.dynamodb_client) - # MOCK SOURCE FILE WITH 100 ROWS TO SIMULATE THE SCENARIO WHERE THE ACK FILE IS NO FULL. - # TODO: Test all other scenarios. mock_source_file_with_100_rows = StringIO("\n".join(f"Row {i}" for i in range(1, 101))) self.s3_client.put_object( Bucket=BucketNames.SOURCE, @@ -61,14 +63,14 @@ def setUp(self) -> None: self.mock_logger_info = self.logger_info_patcher.start() def tearDown(self) -> None: - GenericTearDown(self.s3_client, self.firehose_client) + GenericTearDown(self.s3_client, self.firehose_client, self.dynamodb_client) self.mock_logger_info.stop() @staticmethod def generate_event(test_messages: list[dict]) -> dict: """ Returns an event where each message in the incoming message body list is based on a standard mock message, - updated with the details from the corresponsing message in the given test_messages list. + updated with the details from the corresponding message in the given test_messages list. """ incoming_message_body = [ ( @@ -80,8 +82,38 @@ def generate_event(test_messages: list[dict]) -> dict: ] return {"Records": [{"body": json.dumps(incoming_message_body)}]} + def assert_ack_and_source_file_locations_correct( + self, + source_file_key: str, + tmp_ack_file_key: str, + complete_ack_file_key: str, + is_complete: bool, + ) -> None: + """Helper function to check the ack and source files have not been moved as the processing is not yet + complete""" + if is_complete: + ack_file = self.s3_client.get_object(Bucket=BucketNames.DESTINATION, Key=complete_ack_file_key) + else: + ack_file = self.s3_client.get_object(Bucket=BucketNames.DESTINATION, Key=tmp_ack_file_key) + self.assertIsNotNone(ack_file["Body"].read()) + + full_src_file_key = f"archive/{source_file_key}" if is_complete else f"processing/{source_file_key}" + src_file = self.s3_client.get_object(Bucket=BucketNames.SOURCE, Key=full_src_file_key) + self.assertIsNotNone(src_file["Body"].read()) + + def assert_audit_entry_status_equals(self, message_id: str, status: str) -> None: + """Checks the audit entry status is as expected""" + audit_entry = self.dynamodb_client.get_item( + TableName=AUDIT_TABLE_NAME, Key={"message_id": {"S": message_id}} + ).get("Item") + + actual_status = audit_entry.get("status", {}).get("S") + self.assertEqual(actual_status, status) + def test_lambda_handler_main_multiple_records(self): """Test lambda handler with multiple records.""" + # Set up an audit entry which does not yet have record_count recorded + add_audit_entry_to_table(self.dynamodb_client, "row") # First array of messages: all successful. 
Rows 1 to 3 array_of_success_messages = [ { @@ -147,68 +179,41 @@ def test_lambda_handler_main_multiple_records(self): def test_lambda_handler_main(self): """Test lambda handler with consitent ack_file_name and message_template.""" + # Set up an audit entry which does not yet have record_count recorded + add_audit_entry_to_table(self.dynamodb_client, "row") test_cases = [ { "description": "Multiple messages: all successful", - "messages": [{"row_id": f"row_{i+1}"} for i in range(10)], + "messages": [{"row_id": f"row^{i+1}"} for i in range(10)], }, { "description": "Multiple messages: all with diagnostics (failure messages)", "messages": [ - { - "row_id": "row_1", - "diagnostics": DiagnosticsDictionaries.UNIQUE_ID_MISSING, - }, - { - "row_id": "row_2", - "diagnostics": DiagnosticsDictionaries.NO_PERMISSIONS, - }, - { - "row_id": "row_3", - "diagnostics": DiagnosticsDictionaries.RESOURCE_NOT_FOUND_ERROR, - }, + {"row_id": "row^1", "diagnostics": DiagnosticsDictionaries.UNIQUE_ID_MISSING}, + {"row_id": "row^2", "diagnostics": DiagnosticsDictionaries.NO_PERMISSIONS}, + {"row_id": "row^3", "diagnostics": DiagnosticsDictionaries.RESOURCE_NOT_FOUND_ERROR}, ], }, { "description": "Multiple messages: mixture of success and failure messages", "messages": [ - {"row_id": "row_1", "imms_id": "TEST_IMMS_ID"}, - { - "row_id": "row_2", - "diagnostics": DiagnosticsDictionaries.UNIQUE_ID_MISSING, - }, - { - "row_id": "row_3", - "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR, - }, - {"row_id": "row_4"}, - { - "row_id": "row_5", - "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR, - }, - { - "row_id": "row_6", - "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR, - }, - {"row_id": "row_7"}, - { - "row_id": "row_8", - "diagnostics": DiagnosticsDictionaries.IDENTIFIER_DUPLICATION_ERROR, - }, + {"row_id": "row^1", "imms_id": "TEST_IMMS_ID"}, + {"row_id": "row^2", "diagnostics": DiagnosticsDictionaries.UNIQUE_ID_MISSING}, + {"row_id": "row^3", "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR}, + {"row_id": "row^4"}, + {"row_id": "row^5", "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR}, + {"row_id": "row^6", "diagnostics": DiagnosticsDictionaries.CUSTOM_VALIDATION_ERROR}, + {"row_id": "row^7"}, + {"row_id": "row^8", "diagnostics": DiagnosticsDictionaries.IDENTIFIER_DUPLICATION_ERROR}, ], }, { "description": "Single row: success", - "messages": [{"row_id": "row_1"}], + "messages": [{"row_id": "row^1"}], }, { "description": "Single row: malformed diagnostics info from forwarder", - "messages": [ - { - "row_id": "row_1", - "diagnostics": "SHOULD BE A DICTIONARY, NOT A STRING", - } - ], + "messages": [{"row_id": "row^1", "diagnostics": "SHOULD BE A DICTIONARY, NOT A STRING"}], }, ] @@ -224,23 +229,130 @@ def test_lambda_handler_main(self): Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key, ) - # Test scenario where there is an existing ack file - # TODO: None of the test cases have any existing ack file content? 
- with self.subTest(msg=f"Existing ack file: {test_case['description']}"): - existing_ack_file_content = test_case.get("existing_ack_file_content", "") - setup_existing_ack_file( - MOCK_MESSAGE_DETAILS.temp_ack_file_key, - existing_ack_file_content, - self.s3_client, - ) - response = lambda_handler(event=self.generate_event(test_case["messages"]), context={}) - self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) - validate_ack_file_content(self.s3_client, test_case["messages"], existing_ack_file_content) + def test_lambda_handler_updates_ack_file_but_does_not_mark_complete_when_records_still_remaining(self): + """ + Test that the batch file process is not marked as complete when not all records have been processed. + This means: + - the ack file remains in the TempAck directory + - the source file remains in the processing directory + - all ack records in the event are written to the temporary ack + """ + mock_batch_message_id = "b500efe4-6e75-4768-a38b-6127b3c7b8e0" - self.s3_client.delete_object( - Bucket=BucketNames.DESTINATION, - Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key, - ) + # Original source file had 100 records + add_audit_entry_to_table(self.dynamodb_client, mock_batch_message_id, record_count=100) + array_of_success_messages = [ + { + **BASE_SUCCESS_MESSAGE, + "row_id": f"{mock_batch_message_id}^{i}", + "imms_id": f"imms_{i}", + "local_id": f"local^{i}", + } + for i in range(1, 4) + ] + test_event = {"Records": [{"body": json.dumps(array_of_success_messages)}]} + + response = lambda_handler(event=test_event, context={}) + + self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) + validate_ack_file_content( + self.s3_client, + [*array_of_success_messages], + existing_file_content=ValidValues.ack_headers, + ) + self.assert_ack_and_source_file_locations_correct( + MOCK_MESSAGE_DETAILS.file_key, + MOCK_MESSAGE_DETAILS.temp_ack_file_key, + MOCK_MESSAGE_DETAILS.archive_ack_file_key, + is_complete=False, + ) + self.assert_audit_entry_status_equals(mock_batch_message_id, "Preprocessed") + + @patch("utils_for_ack_lambda.get_record_count_by_message_id", return_value=500) + def test_lambda_handler_uses_message_id_to_record_count_cache_to_reduce_ddb_calls(self, mock_get_record_count: Mock): + """The DynamoDB Audit table is used to store the total record count for each source file. 
To reduce calls each + time - this test checks that we cache the value as this lambda is called many times for large files""" + mock_batch_message_id = "622cdeea-461e-4a83-acb5-7871d47ddbcd" + + # Original source file had 500 records + add_audit_entry_to_table(self.dynamodb_client, mock_batch_message_id, record_count=500) + + message_one = [ + {**BASE_SUCCESS_MESSAGE, "row_id": f"{mock_batch_message_id}^1", "imms_id": "imms_1", "local_id": "local^1"} + ] + message_two = [ + {**BASE_SUCCESS_MESSAGE, "row_id": f"{mock_batch_message_id}^2", "imms_id": "imms_2", "local_id": "local^2"} + ] + test_event_one = {"Records": [{"body": json.dumps(message_one)}]} + test_event_two = {"Records": [{"body": json.dumps(message_two)}]} + + response = lambda_handler(event=test_event_one, context={}) + self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) + second_invocation_response = lambda_handler(event=test_event_two, context={}) + self.assertEqual(second_invocation_response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) + + # Assert that the DDB call is only performed once on the first invocation + mock_get_record_count.assert_called_once_with(mock_batch_message_id) + validate_ack_file_content( + self.s3_client, + [*message_one, *message_two], + existing_file_content=ValidValues.ack_headers, + ) + self.assert_ack_and_source_file_locations_correct( + MOCK_MESSAGE_DETAILS.file_key, + MOCK_MESSAGE_DETAILS.temp_ack_file_key, + MOCK_MESSAGE_DETAILS.archive_ack_file_key, + is_complete=False, + ) + self.assertEqual(_BATCH_EVENT_ID_TO_RECORD_COUNT_MAP[mock_batch_message_id], 500) + self.assert_audit_entry_status_equals(mock_batch_message_id, "Preprocessed") + + def test_lambda_handler_updates_ack_file_and_marks_complete_when_all_records_processed(self): + """ + Test that the batch file process is marked as complete when all records have been processed. 
+ This means: + - the ack file moves from the TempAck directory to the forwardedFile directory + - the source file moves from the processing to the archive directory + - all ack records in the event are appended to the existing temporary ack file + - the DDB Audit Table status is set as 'Processed' + """ + mock_batch_message_id = "75db20e6-c0b5-4012-a8bc-f861a1dd4b22" + + # Original source file had 100 records + add_audit_entry_to_table(self.dynamodb_client, mock_batch_message_id, record_count=100) + + # Previous invocations have already created and added to the temp ack file + existing_ack_content = generate_sample_existing_ack_content() + self.s3_client.put_object( + Bucket=BucketNames.DESTINATION, + Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key, + Body=StringIO(existing_ack_content).getvalue(), + ) + + array_of_success_messages = [ + { + **BASE_SUCCESS_MESSAGE, + "row_id": f"{mock_batch_message_id}^{i}", + "imms_id": f"imms_{i}", + "local_id": f"local^{i}", + } + for i in range(50, 101) + ] + test_event = {"Records": [{"body": json.dumps(array_of_success_messages)}]} + + response = lambda_handler(event=test_event, context={}) + + self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) + validate_ack_file_content( + self.s3_client, [*array_of_success_messages], existing_file_content=existing_ack_content, is_complete=True + ) + self.assert_ack_and_source_file_locations_correct( + MOCK_MESSAGE_DETAILS.file_key, + MOCK_MESSAGE_DETAILS.temp_ack_file_key, + MOCK_MESSAGE_DETAILS.archive_ack_file_key, + is_complete=True, + ) + self.assert_audit_entry_status_equals(mock_batch_message_id, "Processed") def test_lambda_handler_error_scenarios(self): """Test that the lambda handler raises appropriate exceptions for malformed event data.""" diff --git a/lambdas/ack_backend/tests/test_audit_table.py b/lambdas/ack_backend/tests/test_audit_table.py index cea36ea40..2f362bf61 100644 --- a/lambdas/ack_backend/tests/test_audit_table.py +++ b/lambdas/ack_backend/tests/test_audit_table.py @@ -13,8 +13,7 @@ def setUp(self): self.mock_dynamodb_client = self.dynamodb_client_patcher.start() def tearDown(self): - self.logger_patcher.stop() - self.dynamodb_client_patcher.stop() + patch.stopall() def test_change_audit_table_status_to_processed_success(self): # Should not raise @@ -29,3 +28,21 @@ def test_change_audit_table_status_to_processed_raises(self): audit_table.change_audit_table_status_to_processed("file1", "msg1") self.assertIn("fail!", str(ctx.exception)) self.mock_logger.error.assert_called_once() + + def test_get_record_count_by_message_id_returns_the_record_count(self): + """Test that get_record_count_by_message_id retrieves the integer value of the total record count""" + test_message_id = "1234" + + self.mock_dynamodb_client.get_item.return_value = { + "Item": {"message_id": {"S": test_message_id}, "record_count": {"N": "1000"}} + } + + self.assertEqual(audit_table.get_record_count_by_message_id(test_message_id), 1000) + + def test_get_record_count_by_message_id_returns_none_if_record_count_not_set(self): + """Test that if the record count has not yet been set on the audit item then None is returned""" + test_message_id = "1234" + + self.mock_dynamodb_client.get_item.return_value = {"Item": {"message_id": {"S": test_message_id}}} + + self.assertIsNone(audit_table.get_record_count_by_message_id(test_message_id)) diff --git a/lambdas/ack_backend/tests/test_convert_message_to_ack_row.py b/lambdas/ack_backend/tests/test_convert_message_to_ack_row.py index 45d56bed2..143b9f482 100644 --- 
a/lambdas/ack_backend/tests/test_convert_message_to_ack_row.py +++ b/lambdas/ack_backend/tests/test_convert_message_to_ack_row.py @@ -2,15 +2,8 @@ import unittest from unittest.mock import patch -from boto3 import client as boto3_client -from moto import mock_s3, mock_firehose - -from tests.utils.mock_environment_variables import MOCK_ENVIRONMENT_DICT, REGION_NAME -from tests.utils.generic_setup_and_teardown_for_ack_backend import ( - GenericSetUp, - GenericTearDown, -) +from tests.utils.mock_environment_variables import MOCK_ENVIRONMENT_DICT from tests.utils.values_for_ack_backend_tests import ( DefaultValues, ValidValues, @@ -23,23 +16,10 @@ get_error_message_for_ack_file, ) -s3_client = boto3_client("s3", region_name=REGION_NAME) -firehose_client = boto3_client("firehose", region_name=REGION_NAME) - -@mock_firehose -@mock_s3 class TestAckProcessor(unittest.TestCase): """Tests for the ack processor lambda handler.""" - def setUp(self) -> None: - self.s3_client = boto3_client("s3", region_name=REGION_NAME) - self.firehose_client = boto3_client("firehose", region_name=REGION_NAME) - GenericSetUp(self.s3_client, self.firehose_client) - - def tearDown(self) -> None: - GenericTearDown(self.s3_client, self.firehose_client) - def test_get_error_message_for_ack_file(self): """Test the get_error_message_for_ack_file function.""" diagnastics_unclear_error_message = "Unable to determine diagnostics issue" diff --git a/lambdas/ack_backend/tests/test_splunk_logging.py b/lambdas/ack_backend/tests/test_splunk_logging.py index 90229d89e..fac195547 100644 --- a/lambdas/ack_backend/tests/test_splunk_logging.py +++ b/lambdas/ack_backend/tests/test_splunk_logging.py @@ -5,7 +5,7 @@ import json from io import StringIO from contextlib import ExitStack -from moto import mock_s3 +from moto import mock_aws from boto3 import client as boto3_client from tests.utils.values_for_ack_backend_tests import ( @@ -26,7 +26,7 @@ @patch.dict("os.environ", MOCK_ENVIRONMENT_DICT) -@mock_s3 +@mock_aws class TestLoggingDecorators(unittest.TestCase): """Tests for the ack lambda logging decorators""" @@ -97,7 +97,7 @@ def expected_lambda_handler_logs(self, success: bool, number_of_rows, ingestion_ # plus 2 seconds for the handler if it succeeds (i.e. 
it calls update_ack_file) or 1 second if it doesn't; # plus an extra second if ingestion is complete if success: - time_taken = f"{number_of_rows * 2 + 3}.0s" if ingestion_complete else f"{number_of_rows * 2 + 2}.0s" + time_taken = f"{number_of_rows * 2 + 3}.0s" if ingestion_complete else f"{number_of_rows * 2 + 1}.0s" else: time_taken = f"{number_of_rows * 2 + 1}.0s" @@ -118,6 +118,7 @@ def test_splunk_logging_successful_rows(self): for operation in ["CREATE", "UPDATE", "DELETE"]: with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=False), patch("common.log_decorator.logger") as mock_logger, # noqa: E999 ): # noqa: E999 result = lambda_handler( @@ -161,7 +162,7 @@ def test_splunk_logging_missing_data(self): patch("common.log_decorator.logger") as mock_logger, # noqa: E999 ): # noqa: E999 with self.assertRaises(AttributeError): - lambda_handler(event={"Records": [{"body": json.dumps([{"": "456"}])}]}, context={}) + lambda_handler(event={"Records": [{"body": json.dumps([{"": "456", "row_id": "test^1"}])}]}, context={}) expected_first_logger_info_data = {**InvalidValues.logging_with_no_values} @@ -190,7 +191,7 @@ def test_splunk_logging_statuscode_diagnostics( self, mock_send_log_to_firehose, ): - """'Tests the correct codes are returned for diagnostics""" + """Tests the correct codes are returned for diagnostics""" test_cases = [ { "diagnostics": DiagnosticsDictionaries.RESOURCE_FOUND_ERROR, @@ -221,6 +222,7 @@ def test_splunk_logging_statuscode_diagnostics( for test_case in test_cases: with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=False), patch("common.log_decorator.logger") as mock_logger, # noqa: E999 ): # noqa: E999 result = lambda_handler( @@ -254,25 +256,20 @@ def test_splunk_logging_statuscode_diagnostics( def test_splunk_logging_multiple_rows(self): """Tests logging for multiple objects in the body of the event""" - messages = [{"row_id": "test1"}, {"row_id": "test2"}] + messages = [{"row_id": "test^1"}, {"row_id": "test^2"}] with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=False), patch("common.log_decorator.logger") as mock_logger, # noqa: E999 ): # noqa: E999 result = lambda_handler(generate_event(messages), context={}) self.assertEqual(result, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS) - expected_first_logger_info_data = { - **ValidValues.mock_message_expected_log_value, - "message_id": "test1", - } + expected_first_logger_info_data = {**ValidValues.mock_message_expected_log_value, "message_id": "test^1"} - expected_second_logger_info_data = { - **ValidValues.mock_message_expected_log_value, - "message_id": "test2", - } + expected_second_logger_info_data = {**ValidValues.mock_message_expected_log_value, "message_id": "test^2"} expected_third_logger_info_data = self.expected_lambda_handler_logs(success=True, number_of_rows=2) @@ -300,24 +297,21 @@ def test_splunk_logging_multiple_with_diagnostics( """Tests logging for multiple objects in the body of the event with diagnostics""" messages = [ { - "row_id": "test1", + "row_id": "test^1", "operation_requested": "CREATE", "diagnostics": DiagnosticsDictionaries.RESOURCE_FOUND_ERROR, }, { - "row_id": "test2", + "row_id": "test^2", "operation_requested": 
"UPDATE", "diagnostics": DiagnosticsDictionaries.MESSAGE_NOT_SUCCESSFUL_ERROR, }, - { - "row_id": "test3", - "operation_requested": "DELETE", - "diagnostics": DiagnosticsDictionaries.NO_PERMISSIONS, - }, + {"row_id": "test^3", "operation_requested": "DELETE", "diagnostics": DiagnosticsDictionaries.NO_PERMISSIONS}, ] with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=False), patch("common.log_decorator.logger") as mock_logger, # noqa: E999 ): # noqa: E999 result = lambda_handler(generate_event(messages), context={}) @@ -326,7 +320,7 @@ def test_splunk_logging_multiple_with_diagnostics( expected_first_logger_info_data = { **ValidValues.mock_message_expected_log_value, - "message_id": "test1", + "message_id": "test^1", "operation_requested": "CREATE", "statusCode": DiagnosticsDictionaries.RESOURCE_FOUND_ERROR["statusCode"], "status": "fail", @@ -335,7 +329,7 @@ def test_splunk_logging_multiple_with_diagnostics( expected_second_logger_info_data = { **ValidValues.mock_message_expected_log_value, - "message_id": "test2", + "message_id": "test^2", "operation_requested": "UPDATE", "statusCode": DiagnosticsDictionaries.MESSAGE_NOT_SUCCESSFUL_ERROR["statusCode"], "status": "fail", @@ -344,7 +338,7 @@ def test_splunk_logging_multiple_with_diagnostics( expected_third_logger_info_data = { **ValidValues.mock_message_expected_log_value, - "message_id": "test3", + "message_id": "test^3", "operation_requested": "DELETE", "statusCode": DiagnosticsDictionaries.NO_PERMISSIONS["statusCode"], "status": "fail", @@ -378,12 +372,13 @@ def test_splunk_update_ack_file_not_logged(self): # send 98 messages messages = [] for i in range(1, 99): - message_value = "test" + str(i) + message_value = "test^" + str(i) messages.append({"row_id": message_value}) with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 patch("common.log_decorator.logger") as mock_logger, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=False), patch( "update_ack_file.change_audit_table_status_to_processed" ) as mock_change_audit_table_status_to_processed, # noqa: E999 @@ -394,7 +389,7 @@ def test_splunk_update_ack_file_not_logged(self): expected_secondlast_logger_info_data = { **ValidValues.mock_message_expected_log_value, - "message_id": "test98", + "message_id": "test^98", } expected_last_logger_info_data = self.expected_lambda_handler_logs(success=True, number_of_rows=98) @@ -418,12 +413,13 @@ def test_splunk_update_ack_file_logged(self): # send 99 messages messages = [] for i in range(1, 100): - message_value = "test" + str(i) + message_value = "test^" + str(i) messages.append({"row_id": message_value}) with ( # noqa: E999 patch("common.log_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 patch("common.log_decorator.logger") as mock_logger, # noqa: E999 + patch("ack_processor.is_ack_processing_complete", return_value=True), patch( "update_ack_file.change_audit_table_status_to_processed" ) as mock_change_audit_table_status_to_processed, # noqa: E999 @@ -434,11 +430,11 @@ def test_splunk_update_ack_file_logged(self): expected_thirdlast_logger_info_data = { **ValidValues.mock_message_expected_log_value, - "message_id": "test99", + "message_id": "test^99", } expected_secondlast_logger_info_data = { **ValidValues.upload_ack_file_expected_log, - "message_id": "test1", + "message_id": "test", "time_taken": 
"1.0s", } expected_last_logger_info_data = self.expected_lambda_handler_logs( diff --git a/lambdas/ack_backend/tests/test_update_ack_file.py b/lambdas/ack_backend/tests/test_update_ack_file.py index c9ee7de7a..8cea1201e 100644 --- a/lambdas/ack_backend/tests/test_update_ack_file.py +++ b/lambdas/ack_backend/tests/test_update_ack_file.py @@ -3,19 +3,19 @@ import unittest import os from boto3 import client as boto3_client -from moto import mock_s3 +from moto import mock_aws -from tests.utils.values_for_ack_backend_tests import ValidValues, DefaultValues -from tests.utils.mock_environment_variables import ( +from utils.values_for_ack_backend_tests import ValidValues, DefaultValues +from utils.mock_environment_variables import ( MOCK_ENVIRONMENT_DICT, BucketNames, REGION_NAME, ) -from tests.utils.generic_setup_and_teardown_for_ack_backend import ( +from utils.generic_setup_and_teardown_for_ack_backend import ( GenericSetUp, GenericTearDown, ) -from tests.utils.utils_for_ack_backend_tests import ( +from utils.utils_for_ack_backend_tests import ( setup_existing_ack_file, obtain_current_ack_file_content, generate_expected_ack_file_row, @@ -38,7 +38,7 @@ @patch.dict(os.environ, MOCK_ENVIRONMENT_DICT) -@mock_s3 +@mock_aws class TestUpdateAckFile(unittest.TestCase): """Tests for the functions in the update_ack_file module.""" @@ -131,9 +131,6 @@ def test_update_ack_file(self): with self.subTest(test_case["description"]): update_ack_file( file_key=MOCK_MESSAGE_DETAILS.file_key, - message_id=MOCK_MESSAGE_DETAILS.message_id, - supplier=MOCK_MESSAGE_DETAILS.supplier, - vaccine_type=MOCK_MESSAGE_DETAILS.vaccine_type, created_at_formatted_string=MOCK_MESSAGE_DETAILS.created_at_formatted_string, ack_data_rows=test_case["input_rows"], ) @@ -159,9 +156,6 @@ def test_update_ack_file_existing(self): ] update_ack_file( file_key=MOCK_MESSAGE_DETAILS.file_key, - message_id=MOCK_MESSAGE_DETAILS.message_id, - supplier=MOCK_MESSAGE_DETAILS.supplier, - vaccine_type=MOCK_MESSAGE_DETAILS.vaccine_type, created_at_formatted_string=MOCK_MESSAGE_DETAILS.created_at_formatted_string, ack_data_rows=ack_data_rows, ) diff --git a/lambdas/ack_backend/tests/test_update_ack_file_flow.py b/lambdas/ack_backend/tests/test_update_ack_file_flow.py index 6ad16c148..73b3ba273 100644 --- a/lambdas/ack_backend/tests/test_update_ack_file_flow.py +++ b/lambdas/ack_backend/tests/test_update_ack_file_flow.py @@ -1,17 +1,14 @@ from unittest.mock import patch -from io import StringIO - import update_ack_file import unittest import boto3 -from moto import mock_s3 +from moto import mock_aws -@mock_s3 +@mock_aws class TestUpdateAckFileFlow(unittest.TestCase): def setUp(self): - # Patch all AWS and external dependencies self.s3_client = boto3.client("s3", region_name="eu-west-2") self.ack_bucket_name = "my-ack-bucket" @@ -37,45 +34,38 @@ def setUp(self): self.logger_patcher = patch("update_ack_file.logger") self.mock_logger = self.logger_patcher.start() - self.get_row_count_patcher = patch("update_ack_file.get_row_count") - self.mock_get_row_count = self.get_row_count_patcher.start() - self.change_audit_status_patcher = patch("update_ack_file.change_audit_table_status_to_processed") self.mock_change_audit_status = self.change_audit_status_patcher.start() def tearDown(self): self.logger_patcher.stop() - self.get_row_count_patcher.stop() self.change_audit_status_patcher.stop() - def test_audit_table_updated_correctly(self): + def test_audit_table_updated_correctly_when_ack_process_complete(self): """VED-167 - Test that the audit table has been updated 
correctly""" # Setup - self.mock_get_row_count.side_effect = [3, 3] - accumulated_csv_content = StringIO("header1|header2\n") - ack_data_rows = [ - {"a": 1, "b": 2, "row": "audit-test-1"}, - {"a": 3, "b": 4, "row": "audit-test-2"}, - {"a": 5, "b": 6, "row": "audit-test-3"}, - ] message_id = "msg-audit-table" + mock_created_at_string = "created_at_formatted_string" file_key = "audit_table_test.csv" self.s3_client.put_object( Bucket=self.source_bucket_name, Key=f"processing/{file_key}", Body="dummy content", ) + self.s3_client.put_object( + Bucket=self.ack_bucket_name, Key=f"TempAck/audit_table_test_BusAck_{mock_created_at_string}.csv" + ) + # Act - update_ack_file.upload_ack_file( - temp_ack_file_key=f"TempAck/{file_key}", + update_ack_file.complete_batch_file_process( message_id=message_id, supplier="queue-audit-table-supplier", vaccine_type="vaccine-type", - accumulated_csv_content=accumulated_csv_content, - ack_data_rows=ack_data_rows, - archive_ack_file_key=f"forwardedFile/{file_key}", + created_at_formatted_string=mock_created_at_string, file_key=file_key, + total_ack_rows_processed=3, ) + # Assert: Only check audit table update self.mock_change_audit_status.assert_called_once_with(file_key, message_id) diff --git a/lambdas/ack_backend/tests/utils/__init__.py b/lambdas/ack_backend/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lambdas/ack_backend/tests/utils/generic_setup_and_teardown_for_ack_backend.py b/lambdas/ack_backend/tests/utils/generic_setup_and_teardown_for_ack_backend.py index f30dc8244..e35c8d863 100644 --- a/lambdas/ack_backend/tests/utils/generic_setup_and_teardown_for_ack_backend.py +++ b/lambdas/ack_backend/tests/utils/generic_setup_and_teardown_for_ack_backend.py @@ -1,6 +1,8 @@ """Generic setup and teardown for ACK backend tests""" -from tests.utils.mock_environment_variables import BucketNames, Firehose, REGION_NAME +from tests.utils.mock_environment_variables import AUDIT_TABLE_NAME, BucketNames, Firehose, REGION_NAME + +from constants import AuditTableKeys class GenericSetUp: @@ -11,7 +13,7 @@ class GenericSetUp: * If firehose_client is provided, creates a firehose delivery stream """ - def __init__(self, s3_client=None, firehose_client=None): + def __init__(self, s3_client=None, firehose_client=None, dynamodb_client=None): if s3_client: for bucket_name in [ @@ -35,11 +37,19 @@ def __init__(self, s3_client=None, firehose_client=None): }, ) + if dynamodb_client: + dynamodb_client.create_table( + TableName=AUDIT_TABLE_NAME, + KeySchema=[{"AttributeName": AuditTableKeys.MESSAGE_ID, "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": AuditTableKeys.MESSAGE_ID, "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + class GenericTearDown: """Performs generic tear down of mock resources""" - def __init__(self, s3_client=None, firehose_client=None): + def __init__(self, s3_client=None, firehose_client=None, dynamodb_client=None): if s3_client: for bucket_name in [ @@ -53,3 +63,6 @@ def __init__(self, s3_client=None, firehose_client=None): if firehose_client: firehose_client.delete_delivery_stream(DeliveryStreamName=Firehose.STREAM_NAME) + + if dynamodb_client: + dynamodb_client.delete_table(TableName=AUDIT_TABLE_NAME) diff --git a/lambdas/ack_backend/tests/utils/mock_environment_variables.py b/lambdas/ack_backend/tests/utils/mock_environment_variables.py index 1a0d9641b..267051df3 100644 --- a/lambdas/ack_backend/tests/utils/mock_environment_variables.py +++ 
b/lambdas/ack_backend/tests/utils/mock_environment_variables.py
@@ -1,6 +1,7 @@
 """Module containing mock environment variables for use in ack backend tests"""
 
 REGION_NAME = "eu-west-2"
+AUDIT_TABLE_NAME = "immunisation-batch-internal-dev-audit-table"
 
 
 class BucketNames:
@@ -20,6 +21,6 @@ class Firehose:
 MOCK_ENVIRONMENT_DICT = {
     "ACK_BUCKET_NAME": BucketNames.DESTINATION,
     "FIREHOSE_STREAM_NAME": Firehose.STREAM_NAME,
-    "AUDIT_TABLE_NAME": "immunisation-batch-internal-dev-audit-table",
+    "AUDIT_TABLE_NAME": AUDIT_TABLE_NAME,
     "SOURCE_BUCKET_NAME": BucketNames.SOURCE,
 }
diff --git a/lambdas/ack_backend/tests/utils/utils_for_ack_backend_tests.py b/lambdas/ack_backend/tests/utils/utils_for_ack_backend_tests.py
index f2153743b..5d8af01cd 100644
--- a/lambdas/ack_backend/tests/utils/utils_for_ack_backend_tests.py
+++ b/lambdas/ack_backend/tests/utils/utils_for_ack_backend_tests.py
@@ -1,13 +1,24 @@
 """Utils functions for the ack backend tests"""
 
 import json
+from typing import Optional
 from boto3 import client as boto3_client
 from tests.utils.values_for_ack_backend_tests import ValidValues, MOCK_MESSAGE_DETAILS
-from tests.utils.mock_environment_variables import REGION_NAME, BucketNames
+from tests.utils.mock_environment_variables import AUDIT_TABLE_NAME, REGION_NAME, BucketNames
 
 firehose_client = boto3_client("firehose", region_name=REGION_NAME)
 
 
+def add_audit_entry_to_table(dynamodb_client, batch_event_message_id: str, record_count: Optional[int] = None) -> None:
+    """Add an entry to the audit table"""
+    audit_table_entry = {"status": {"S": "Preprocessed"}, "message_id": {"S": batch_event_message_id}}
+
+    if record_count is not None:
+        audit_table_entry["record_count"] = {"N": str(record_count)}
+
+    dynamodb_client.put_item(TableName=AUDIT_TABLE_NAME, Item=audit_table_entry)
+
+
 def generate_event(test_messages: list[dict]) -> dict:
     """
     Returns an event where each message in the incoming message body list is based on a standard mock message,
@@ -35,6 +46,14 @@ def obtain_current_ack_file_content(s3_client, temp_ack_file_key: str = MOCK_MES
     return retrieved_object["Body"].read().decode("utf-8")
 
 
+def obtain_completed_ack_file_content(
+    s3_client, complete_ack_file_key: str = MOCK_MESSAGE_DETAILS.archive_ack_file_key
+) -> str:
+    """Obtains the ack file content from the forwardedFile directory"""
+    retrieved_object = s3_client.get_object(Bucket=BucketNames.DESTINATION, Key=complete_ack_file_key)
+    return retrieved_object["Body"].read().decode("utf-8")
+
+
 def generate_expected_ack_file_row(
     success: bool,
     imms_id: str = MOCK_MESSAGE_DETAILS.imms_id,
@@ -94,11 +113,14 @@ def validate_ack_file_content(
     s3_client,
     incoming_messages: list[dict],
     existing_file_content: str = ValidValues.ack_headers,
+    is_complete: bool = False,
 ) -> None:
     """
     Obtains the ack file content and ensures that it matches the expected content
     (expected content is based on the incoming messages).
     """
-    actual_ack_file_content = obtain_current_ack_file_content(s3_client)
+    actual_ack_file_content = (
+        obtain_current_ack_file_content(s3_client) if not is_complete else obtain_completed_ack_file_content(s3_client)
+    )
     expected_ack_file_content = generate_expected_ack_content(incoming_messages, existing_file_content)
     assert expected_ack_file_content == actual_ack_file_content
diff --git a/lambdas/ack_backend/tests/utils/values_for_ack_backend_tests.py b/lambdas/ack_backend/tests/utils/values_for_ack_backend_tests.py
index a3cf45ac8..9814e7d29 100644
--- a/lambdas/ack_backend/tests/utils/values_for_ack_backend_tests.py
+++ b/lambdas/ack_backend/tests/utils/values_for_ack_backend_tests.py
@@ -11,7 +11,7 @@ class DefaultValues:
     fixed_datetime_str = fixed_datetime.strftime("%Y-%m-%d %H:%M:%S")
 
     message_id = "test_file_id"
-    row_id = "test_file_id#1"
+    row_id = "test_file_id^1"
     local_id = "test_system_uri^testabc"
     imms_id = "test_imms_id"
     operation_requested = "CREATE"
@@ -236,7 +236,7 @@ class ValidValues:
     )
 
     upload_ack_file_expected_log = {
-        "function_name": "ack_processor_upload_ack_file",
+        "function_name": "ack_processor_complete_batch_file_process",
         "date_time": fixed_datetime.strftime("%Y-%m-%d %H:%M:%S"),
         "status": "success",
         "supplier": MOCK_MESSAGE_DETAILS.supplier,
@@ -263,7 +263,7 @@ class InvalidValues:
         "supplier": "unknown",
         "file_key": "file_key_missing",
         "vaccine_type": "unknown",
-        "message_id": "unknown",
+        "message_id": "test^1",
         "operation_requested": "unknown",
         "time_taken": "1000.0ms",
         "local_id": "unknown",
diff --git a/recordprocessor/src/audit_table.py b/recordprocessor/src/audit_table.py
index b31c72254..e2948fdd2 100644
--- a/recordprocessor/src/audit_table.py
+++ b/recordprocessor/src/audit_table.py
@@ -7,16 +7,27 @@
 from constants import AUDIT_TABLE_NAME, AuditTableKeys
 
 
-def update_audit_table_status(file_key: str, message_id: str, status: str, error_details: Optional[str] = None) -> None:
+def update_audit_table_status(
+    file_key: str,
+    message_id: str,
+    status: str,
+    error_details: Optional[str] = None,
+    record_count: Optional[int] = None,
+) -> None:
     """Updates the status in the audit table to the requested value"""
-    update_expression = f"SET #{AuditTableKeys.STATUS} = :status"
-    expression_attr_names = {f"#{AuditTableKeys.STATUS}": "status"}
-    expression_attr_values = {":status": {"S": status}}
+    update_expression = f"SET #{AuditTableKeys.STATUS} = :{AuditTableKeys.STATUS}"
+    expression_attr_names = {f"#{AuditTableKeys.STATUS}": AuditTableKeys.STATUS}
+    expression_attr_values = {f":{AuditTableKeys.STATUS}": {"S": status}}
+
+    if record_count is not None:
+        update_expression = update_expression + f", #{AuditTableKeys.RECORD_COUNT} = :{AuditTableKeys.RECORD_COUNT}"
+        expression_attr_names[f"#{AuditTableKeys.RECORD_COUNT}"] = AuditTableKeys.RECORD_COUNT
+        expression_attr_values[f":{AuditTableKeys.RECORD_COUNT}"] = {"N": str(record_count)}
 
     if error_details is not None:
-        update_expression = update_expression + f", #{AuditTableKeys.ERROR_DETAILS} = :error_details"
-        expression_attr_names[f"#{AuditTableKeys.ERROR_DETAILS}"] = "error_details"
-        expression_attr_values[":error_details"] = {"S": error_details}
+        update_expression = update_expression + f", #{AuditTableKeys.ERROR_DETAILS} = :{AuditTableKeys.ERROR_DETAILS}"
+        expression_attr_names[f"#{AuditTableKeys.ERROR_DETAILS}"] = AuditTableKeys.ERROR_DETAILS
+        expression_attr_values[f":{AuditTableKeys.ERROR_DETAILS}"] = {"S": error_details}
 
     try:
         # Update the status in the audit table to "Processed"
@@ -26,7 +37,7 @@ def update_audit_table_status(file_key: str, message_id: str, status: str, error
             UpdateExpression=update_expression,
             ExpressionAttributeNames=expression_attr_names,
             ExpressionAttributeValues=expression_attr_values,
-            ConditionExpression="attribute_exists(message_id)",
+            ConditionExpression=f"attribute_exists({AuditTableKeys.MESSAGE_ID})",
         )
 
         logger.info(
diff --git a/recordprocessor/src/batch_processor.py b/recordprocessor/src/batch_processor.py
index fccd9d0bc..381ed5496 100644
--- a/recordprocessor/src/batch_processor.py
+++ b/recordprocessor/src/batch_processor.py
@@ -96,7 +96,7 @@ def process_csv_to_fhir(incoming_message_body: dict) -> int:
         )
 
         file_status = f"{FileStatus.NOT_PROCESSED} - {FileNotProcessedReason.EMPTY}"
-        update_audit_table_status(file_key, file_id, file_status)
+        update_audit_table_status(file_key, file_id, file_status, record_count=row_count)
 
     return row_count
diff --git a/recordprocessor/src/constants.py b/recordprocessor/src/constants.py
index ffd690af5..2e1944eea 100644
--- a/recordprocessor/src/constants.py
+++ b/recordprocessor/src/constants.py
@@ -75,6 +75,7 @@ class AuditTableKeys:
     FILENAME = "filename"
     MESSAGE_ID = "message_id"
     QUEUE_NAME = "queue_name"
+    RECORD_COUNT = "record_count"
     STATUS = "status"
     TIMESTAMP = "timestamp"
     ERROR_DETAILS = "error_details"
diff --git a/recordprocessor/tests/test_process_csv_to_fhir.py b/recordprocessor/tests/test_process_csv_to_fhir.py
index 3ab097e67..a2af75bc6 100644
--- a/recordprocessor/tests/test_process_csv_to_fhir.py
+++ b/recordprocessor/tests/test_process_csv_to_fhir.py
@@ -82,6 +82,7 @@ def test_process_csv_to_fhir_full_permissions(self):
         expected_table_entry = {
             **test_file.audit_table_entry,
             "status": {"S": FileStatus.PREPROCESSED},
+            "record_count": {"N": "3"},
         }
         add_entry_to_table(test_file, FileStatus.PROCESSING)
         self.upload_source_file(
@@ -105,6 +106,7 @@ def test_process_csv_to_fhir_partial_permissions(self):
         expected_table_entry = {
             **test_file.audit_table_entry,
             "status": {"S": FileStatus.PREPROCESSED},
+            "record_count": {"N": "3"},
         }
         add_entry_to_table(test_file, FileStatus.PROCESSING)
         self.upload_source_file(
@@ -125,6 +127,7 @@ def test_process_csv_to_fhir_no_permissions(self):
         expected_table_entry = {
             **test_file.audit_table_entry,
             "status": {"S": FileStatus.PREPROCESSED},
+            "record_count": {"N": "2"},
         }
         add_entry_to_table(test_file, FileStatus.PROCESSING)
         self.upload_source_file(
diff --git a/recordprocessor/tests/test_recordprocessor_main.py b/recordprocessor/tests/test_recordprocessor_main.py
index a79e6247e..5e810a62e 100644
--- a/recordprocessor/tests/test_recordprocessor_main.py
+++ b/recordprocessor/tests/test_recordprocessor_main.py
@@ -9,13 +9,13 @@
 from moto import mock_s3, mock_kinesis, mock_firehose, mock_dynamodb
 from boto3 import client as boto3_client
 
-from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import (
+from utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import (
     GenericSetUp,
     GenericTearDown,
     add_entry_to_table,
     assert_audit_table_entry,
 )
-from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import (
+from utils_for_recordprocessor_tests.values_for_recordprocessor_tests import (
     MockFileDetails,
     FileDetails,
     ValidMockFileContent,
@@ -25,12 +25,12 @@
     InfAckFileRows,
     REGION_NAME,
 )
-from tests.utils_for_recordprocessor_tests.mock_environment_variables import (
+from utils_for_recordprocessor_tests.mock_environment_variables import (
     MOCK_ENVIRONMENT_DICT,
     BucketNames,
     Kinesis,
 )
-from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import (
+from utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import (
     create_patch,
 )
 
@@ -230,7 +230,7 @@ def test_e2e_full_permissions(self):
         ]
         self.make_inf_ack_assertions(file_details=mock_rsv_emis_file, passed_validation=True)
         self.make_kinesis_assertions(assertion_cases)
-        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED)
+        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED, row_count=3)
 
     def test_e2e_partial_permissions(self):
         """
@@ -286,7 +286,7 @@ def test_e2e_partial_permissions(self):
         ]
         self.make_inf_ack_assertions(file_details=mock_rsv_emis_file, passed_validation=True)
         self.make_kinesis_assertions(assertion_cases)
-        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED)
+        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED, row_count=3)
 
     def test_e2e_no_required_permissions(self):
         """
@@ -307,7 +307,7 @@ def test_e2e_no_required_permissions(self):
             self.assertIn("diagnostics", data_dict)
             self.assertNotIn("fhir_json", data_dict)
         self.make_inf_ack_assertions(file_details=mock_rsv_emis_file, passed_validation=True)
-        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED)
+        assert_audit_table_entry(test_file, FileStatus.PREPROCESSED, row_count=2)
 
     def test_e2e_no_permissions(self):
         """
@@ -511,7 +511,7 @@ def test_e2e_empty_file_is_flagged_and_processed_correctly(self):
             "RSV_Vaccinations_v5_8HK48_20210730T12000000.csv",
         )
         self.assertListEqual(kinesis_records, [])
-        assert_audit_table_entry(test_file, "Not processed - Empty file")
+        assert_audit_table_entry(test_file, "Not processed - Empty file", row_count=0)
         self.assert_object_moved_to_archive(test_file.file_key)
 
     def test_e2e_error_is_logged_if_invalid_json_provided(self):
diff --git a/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py b/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py
index d36a7c1e4..883fc8146 100644
--- a/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py
+++ b/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py
@@ -1,21 +1,22 @@
 """Utils for the recordprocessor tests"""
 
 from io import StringIO
-from tests.utils_for_recordprocessor_tests.mock_environment_variables import (
+from utils_for_recordprocessor_tests.mock_environment_variables import (
     BucketNames,
     Firehose,
     Kinesis,
 )
-from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import (
+from utils_for_recordprocessor_tests.values_for_recordprocessor_tests import (
     MockFileDetails,
     FileDetails,
 )
 from boto3.dynamodb.types import TypeDeserializer
 from boto3 import client as boto3_client
 from unittest.mock import patch
-from tests.utils_for_recordprocessor_tests.mock_environment_variables import (
+from utils_for_recordprocessor_tests.mock_environment_variables import (
     MOCK_ENVIRONMENT_DICT,
 )
+from typing import Optional
 
 # Ensure environment variables are mocked before importing from src files
 with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT):
@@ -24,7 +25,6 @@
     from constants import (
         AuditTableKeys,
         AUDIT_TABLE_NAME,
-        FileStatus,
         AUDIT_TABLE_FILENAME_GSI,
         AUDIT_TABLE_QUEUE_NAME_GSI,
     )
@@ -164,7 +164,7 @@ def __init__(
         dynamo_db_client.delete_table(TableName=AUDIT_TABLE_NAME)
 
 
-def add_entry_to_table(file_details: MockFileDetails, file_status: FileStatus) -> None:
+def add_entry_to_table(file_details: MockFileDetails, file_status: str) -> None:
     """Add an entry to the audit table"""
     audit_table_entry = {**file_details.audit_table_entry, "status": {"S": file_status}}
     dynamodb_client.put_item(TableName=AUDIT_TABLE_NAME, Item=audit_table_entry)
@@ -178,16 +178,18 @@ def deserialize_dynamodb_types(dynamodb_table_entry_with_types):
     return {k: TypeDeserializer().deserialize(v) for k, v in dynamodb_table_entry_with_types.items()}
 
 
-def assert_audit_table_entry(file_details: FileDetails, expected_status: FileStatus) -> None:
+def assert_audit_table_entry(file_details: FileDetails, expected_status: str, row_count: Optional[int] = None) -> None:
     """Assert that the file details are in the audit table"""
     table_entry = dynamodb_client.get_item(
         TableName=AUDIT_TABLE_NAME,
         Key={AuditTableKeys.MESSAGE_ID: {"S": file_details.message_id}},
     ).get("Item")
-    assert table_entry == {
-        **file_details.audit_table_entry,
-        "status": {"S": expected_status},
-    }
+    expected_result = {**file_details.audit_table_entry, "status": {"S": expected_status}}
+
+    if row_count is not None:
+        expected_result["record_count"] = {"N": str(row_count)}
+
+    assert table_entry == expected_result
 
 
 def create_patch(target: str):
diff --git a/terraform/ack_lambda.tf b/terraform/ack_lambda.tf
index 2b12e6e8d..865cd3077 100644
--- a/terraform/ack_lambda.tf
+++ b/terraform/ack_lambda.tf
@@ -130,6 +130,7 @@ resource "aws_iam_policy" "ack_lambda_exec_policy" {
       {
         Effect = "Allow"
         Action = [
+          "dynamodb:GetItem",
          "dynamodb:UpdateItem"
        ]
        Resource = [