
Commit 6ef0d2b

behaviour tests

1 parent: 4b1f50d

14 files changed: +644, -271 lines

ack_backend/Makefile

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,6 +6,6 @@ package: build
 	docker run --rm -v $(shell pwd)/build:/build ack-lambda-build
 
 test:
-	python -m unittest
+	@PYTHONPATH=src:tests python -m unittest
 
 .PHONY: build package
```
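Setting `PYTHONPATH=src:tests` on the test recipe is what lets the new behaviour tests import application modules (for example, `import audit_table` in the new test file below) and test helpers directly, without an install step; the leading `@` simply suppresses echoing of the command.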

ack_backend/poetry.lock

Lines changed: 226 additions & 218 deletions
Some generated files are not rendered by default.

ack_backend/src/clients.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -5,13 +5,18 @@
 
 REGION_NAME = "eu-west-2"
 
-s3_client = boto3_client("s3", region_name=REGION_NAME)
 firehose_client = boto3_client("firehose", region_name=REGION_NAME)
 lambda_client = boto3_client('lambda', region_name=REGION_NAME)
 dynamodb_client = boto3_client("dynamodb", region_name=REGION_NAME)
 
 dynamodb_resource = boto3_resource("dynamodb", region_name=REGION_NAME)
 
+s3_client = None
+def get_s3_client():
+    global s3_client
+    if s3_client is None:
+        s3_client = boto3_client("s3", region_name=REGION_NAME)
+    return s3_client
 
 # Logger
 logging.basicConfig(level="INFO")
```
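Previously the S3 client was created at import time, which meant boto3 spun it up before a test had any chance to install fakes. `get_s3_client()` defers creation to first use and caches the result. A minimal sketch of the effect in a test (the `patch` target and the cache reset are illustrative, not part of the commit):

```python
from unittest.mock import patch

import clients

with patch("clients.boto3_client") as mock_boto3_client:
    clients.s3_client = None             # reset the cache so the next call creates a client
    first = clients.get_s3_client()      # first call builds the (mocked) client...
    second = clients.get_s3_client()     # ...later calls return the cached instance
    assert first is second
    mock_boto3_client.assert_called_once_with("s3", region_name="eu-west-2")
```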

ack_backend/src/constants.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -2,13 +2,18 @@
 
 import os
 
-SOURCE_BUCKET_NAME = os.getenv("SOURCE_BUCKET_NAME")
-ACK_BUCKET_NAME = os.getenv("ACK_BUCKET_NAME")
 AUDIT_TABLE_NAME = os.getenv("AUDIT_TABLE_NAME")
+FILE_NAME_PROC_LAMBDA_NAME = os.getenv("FILE_NAME_PROC_LAMBDA_NAME")
 AUDIT_TABLE_FILENAME_GSI = "filename_index"
 AUDIT_TABLE_QUEUE_NAME_GSI = "queue_name_index"
-FILE_NAME_PROC_LAMBDA_NAME = os.getenv("FILE_NAME_PROC_LAMBDA_NAME")
 
+def get_source_bucket_name() -> str:
+    """Get the SOURCE_BUCKET_NAME value from environment variables."""
+    return os.getenv("SOURCE_BUCKET_NAME")
+
+def get_ack_bucket_name() -> str:
+    """Get the ACK_BUCKET_NAME value from environment variables."""
+    return os.getenv("ACK_BUCKET_NAME")
 
 class FileStatus:
     """File status constants"""
```

ack_backend/src/update_ack_file.py

Lines changed: 23 additions & 10 deletions
```diff
@@ -4,9 +4,9 @@
 from io import StringIO, BytesIO
 from typing import Union
 from botocore.exceptions import ClientError
-from constants import ACK_HEADERS, SOURCE_BUCKET_NAME, ACK_BUCKET_NAME, FILE_NAME_PROC_LAMBDA_NAME
+from constants import ACK_HEADERS, get_source_bucket_name, get_ack_bucket_name, FILE_NAME_PROC_LAMBDA_NAME
 from audit_table import change_audit_table_status_to_processed, get_next_queued_file_details
-from clients import s3_client, logger, lambda_client
+from clients import get_s3_client, logger, lambda_client
 from utils_for_ack_lambda import get_row_count
 
 
@@ -49,7 +49,7 @@ def obtain_current_ack_content(temp_ack_file_key: str) -> StringIO:
     """Returns the current ack file content if the file exists, or else initialises the content with the ack headers."""
     try:
         # If ack file exists in S3 download the contents
-        existing_ack_file = s3_client.get_object(Bucket=ACK_BUCKET_NAME, Key=temp_ack_file_key)
+        existing_ack_file = get_s3_client().get_object(Bucket=get_ack_bucket_name(), Key=temp_ack_file_key)
         existing_content = existing_ack_file["Body"].read().decode("utf-8")
     except ClientError as error:
         # If ack file does not exist in S3 create a new file containing the headers only
@@ -80,22 +80,26 @@ def upload_ack_file(
     cleaned_row = "|".join(data_row_str).replace(" |", "|").replace("| ", "|").strip()
     accumulated_csv_content.write(cleaned_row + "\n")
     csv_file_like_object = BytesIO(accumulated_csv_content.getvalue().encode("utf-8"))
-    s3_client.upload_fileobj(csv_file_like_object, ACK_BUCKET_NAME, temp_ack_file_key)
 
-    row_count_source = get_row_count(SOURCE_BUCKET_NAME, f"processing/{file_key}")
-    row_count_destination = get_row_count(ACK_BUCKET_NAME, temp_ack_file_key)
+    ack_bucket_name = get_ack_bucket_name()
+    source_bucket_name = get_source_bucket_name()
+
+    get_s3_client().upload_fileobj(csv_file_like_object, ack_bucket_name, temp_ack_file_key)
+
+    row_count_source = get_row_count(source_bucket_name, f"processing/{file_key}")
+    row_count_destination = get_row_count(ack_bucket_name, temp_ack_file_key)
     # TODO: Should we check for > and if so what handling is required
     if row_count_destination == row_count_source:
-        move_file(ACK_BUCKET_NAME, temp_ack_file_key, archive_ack_file_key)
-        move_file(SOURCE_BUCKET_NAME, f"processing/{file_key}", f"archive/{file_key}")
+        move_file(ack_bucket_name, temp_ack_file_key, archive_ack_file_key)
+        move_file(source_bucket_name, f"processing/{file_key}", f"archive/{file_key}")
 
         # Update the audit table and invoke the filename lambda with next file in the queue (if one exists)
         change_audit_table_status_to_processed(file_key, message_id)
         next_queued_file_details = get_next_queued_file_details(supplier_queue)
         if next_queued_file_details:
            invoke_filename_lambda(next_queued_file_details["filename"], next_queued_file_details["message_id"])
 
-    logger.info("Ack file updated to %s: %s", ACK_BUCKET_NAME, archive_ack_file_key)
+    logger.info("Ack file updated to %s: %s", ack_bucket_name, archive_ack_file_key)
 
 
 def update_ack_file(
@@ -123,6 +127,7 @@ def update_ack_file(
 
 def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None:
     """Moves a file from one location to another within a single S3 bucket by copying and then deleting the file."""
+    s3_client = get_s3_client()
     s3_client.copy_object(
         Bucket=bucket_name, CopySource={"Bucket": bucket_name, "Key": source_file_key}, Key=destination_file_key
     )
@@ -135,7 +140,15 @@ def invoke_filename_lambda(file_key: str, message_id: str) -> None:
     try:
         lambda_payload = {
             "Records": [
-                {"s3": {"bucket": {"name": SOURCE_BUCKET_NAME}, "object": {"key": file_key}}, "message_id": message_id}
+                {"s3":
+                    {
+                        "bucket": {
+                            "name": get_source_bucket_name()
+                        },
+                        "object": {"key": file_key}
+                    },
+                 "message_id": message_id
+                }
             ]
         }
         lambda_client.invoke(
```
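`move_file` now fetches the client itself instead of relying on a module-level import. Per its docstring, an S3 "move" is a copy followed by a delete, since S3 has no native rename; a sketch of the assumed overall shape (the `delete_object` call falls outside the hunk shown above):

```python
def move_file_sketch(s3_client, bucket_name: str, source_key: str, destination_key: str) -> None:
    # Copy the object to its new key, then remove the original.
    s3_client.copy_object(
        Bucket=bucket_name,
        CopySource={"Bucket": bucket_name, "Key": source_key},
        Key=destination_key,
    )
    s3_client.delete_object(Bucket=bucket_name, Key=source_key)
```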
ack_backend/src/utils_for_ack_lambda.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,12 +1,12 @@
 """Utils for ack lambda"""
 
-from clients import s3_client
+from clients import get_s3_client
 
 
 def get_row_count(bucket_name: str, file_key: str) -> int:
     """
     Looks in the given bucket and returns the count of the number of lines in the given file.
     NOTE: Blank lines are not included in the count.
     """
-    response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
+    response = get_s3_client().get_object(Bucket=bucket_name, Key=file_key)
     return sum(1 for line in response["Body"].iter_lines() if line.strip())
```
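A quick illustration of the counting rule with the client stubbed out (file content and names are made up; `StreamingBody` stands in for the real S3 response body):

```python
from io import BytesIO
from unittest.mock import MagicMock, patch

from botocore.response import StreamingBody

import utils_for_ack_lambda

data = b"header\nrow1\n\nrow2\n"  # note the blank line
mock_s3 = MagicMock()
mock_s3.get_object.return_value = {"Body": StreamingBody(BytesIO(data), len(data))}

with patch("utils_for_ack_lambda.get_s3_client", return_value=mock_s3):
    # Blank lines are excluded, so only "header", "row1" and "row2" are counted.
    assert utils_for_ack_lambda.get_row_count("any-bucket", "any-key") == 3
```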

ack_backend/tests/test_ack_processor.py

Lines changed: 14 additions & 12 deletions
```diff
@@ -24,9 +24,6 @@
 with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT):
     from ack_processor import lambda_handler
 
-s3_client = boto3_client("s3", region_name=REGION_NAME)
-firehose_client = boto3_client("firehose", region_name=REGION_NAME)
-
 BASE_SUCCESS_MESSAGE = MOCK_MESSAGE_DETAILS.success_message
 BASE_FAILURE_MESSAGE = {
     **{k: v for k, v in BASE_SUCCESS_MESSAGE.items() if k != "imms_id"},
@@ -41,19 +38,24 @@ class TestAckProcessor(unittest.TestCase):
     """Tests for the ack processor lambda handler."""
 
     def setUp(self) -> None:
-        GenericSetUp(s3_client, firehose_client)
+        self.s3_client = boto3_client("s3", region_name=REGION_NAME)
+        self.firehose_client = boto3_client("firehose", region_name=REGION_NAME)
+        GenericSetUp(self.s3_client, self.firehose_client)
 
         # MOCK SOURCE FILE WITH 100 ROWS TO SIMULATE THE SCENARIO WHERE THE ACK FILE IS NOT FULL.
         # TODO: Test all other scenarios.
         mock_source_file_with_100_rows = StringIO("\n".join(f"Row {i}" for i in range(1, 101)))
-        s3_client.put_object(
+        self.s3_client.put_object(
             Bucket=BucketNames.SOURCE,
             Key=f"processing/{MOCK_MESSAGE_DETAILS.file_key}",
             Body=mock_source_file_with_100_rows.getvalue(),
         )
+        self.logger_info_patcher = patch('logging_decorators.logger.info')
+        self.mock_logger_info = self.logger_info_patcher.start()
 
     def tearDown(self) -> None:
-        GenericTearDown(s3_client, firehose_client)
+        GenericTearDown(self.s3_client, self.firehose_client)
+        self.logger_info_patcher.stop()
 
     @staticmethod
     def generate_event(test_messages: list[dict]) -> dict:
@@ -111,7 +113,7 @@ def test_lambda_handler_main_multiple_records(self):
         response = lambda_handler(event=event, context={})
 
         self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-        validate_ack_file_content(
+        validate_ack_file_content(self.s3_client,
             [*array_of_success_messages, *array_of_failure_messages, *array_of_mixed_success_and_failure_messages],
             existing_file_content=ValidValues.ack_headers,
         )
@@ -159,20 +161,20 @@ def test_lambda_handler_main(self):
             with self.subTest(msg=f"No existing ack file: {test_case['description']}"):
                 response = lambda_handler(event=self.generate_event(test_case["messages"]), context={})
                 self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-                validate_ack_file_content(test_case["messages"])
+                validate_ack_file_content(self.s3_client, test_case["messages"])
 
-            s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
+            self.s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
 
             # Test scenario where there is an existing ack file
             # TODO: None of the test cases have any existing ack file content?
             with self.subTest(msg=f"Existing ack file: {test_case['description']}"):
                 existing_ack_file_content = test_case.get("existing_ack_file_content", "")
-                setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_ack_file_content)
+                setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_ack_file_content, self.s3_client)
                 response = lambda_handler(event=self.generate_event(test_case["messages"]), context={})
                 self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-                validate_ack_file_content(test_case["messages"], existing_ack_file_content)
+                validate_ack_file_content(self.s3_client, test_case["messages"], existing_ack_file_content)
 
-            s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
+            self.s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
 
     def test_lambda_handler_error_scenarios(self):
         """Test that the lambda handler raises appropriate exceptions for malformed event data."""
```
ack_backend/tests/test_audit_table.py

Lines changed: 55 additions & 0 deletions

```diff
@@ -0,0 +1,55 @@
+import unittest
+from unittest.mock import patch, MagicMock
+import audit_table
+from errors import UnhandledAuditTableError
+
+class TestAuditTable(unittest.TestCase):
+
+    def setUp(self):
+        self.logger_patcher = patch('audit_table.logger')
+        self.mock_logger = self.logger_patcher.start()
+        self.dynamodb_resource_patcher = patch('audit_table.dynamodb_resource')
+        self.mock_dynamodb_resource = self.dynamodb_resource_patcher.start()
+        self.dynamodb_client_patcher = patch('audit_table.dynamodb_client')
+        self.mock_dynamodb_client = self.dynamodb_client_patcher.start()
+
+    def tearDown(self):
+        self.logger_patcher.stop()
+        self.dynamodb_resource_patcher.stop()
+        self.dynamodb_client_patcher.stop()
+
+    def test_get_next_queued_file_details_returns_oldest(self):
+        # Arrange
+        mock_table = MagicMock()
+        self.mock_dynamodb_resource.Table.return_value = mock_table
+        mock_table.query.return_value = {
+            "Items": [
+                {"timestamp": 2, "my-key": "value2"},
+                {"timestamp": 1, "my-key": "value1"},
+            ]
+        }
+        # Act
+        result = audit_table.get_next_queued_file_details("queue1")
+        # Assert
+        self.assertEqual(result, {"timestamp": 1, "my-key": "value1"})
+
+    def test_get_next_queued_file_details_returns_none_if_empty(self):
+        mock_table = MagicMock()
+        self.mock_dynamodb_resource.Table.return_value = mock_table
+        mock_table.query.return_value = {"Items": []}
+        result = audit_table.get_next_queued_file_details("queue1")
+        self.assertIsNone(result)
+
+    def test_change_audit_table_status_to_processed_success(self):
+        # Should not raise
+        self.mock_dynamodb_client.update_item.return_value = {}
+        audit_table.change_audit_table_status_to_processed("file1", "msg1")
+        self.mock_dynamodb_client.update_item.assert_called_once()
+        self.mock_logger.info.assert_called_once()
+
+    def test_change_audit_table_status_to_processed_raises(self):
+        self.mock_dynamodb_client.update_item.side_effect = Exception("fail!")
+        with self.assertRaises(UnhandledAuditTableError) as ctx:
+            audit_table.change_audit_table_status_to_processed("file1", "msg1")
+        self.assertIn("fail!", str(ctx.exception))
+        self.mock_logger.error.assert_called_once()
```

ack_backend/tests/test_convert_message_to_ack_row.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -27,10 +27,12 @@ class TestAckProcessor(unittest.TestCase):
     """Tests for the ack processor lambda handler."""
 
     def setUp(self) -> None:
-        GenericSetUp(s3_client, firehose_client)
+        self.s3_client = boto3_client("s3", region_name=REGION_NAME)
+        self.firehose_client = boto3_client("firehose", region_name=REGION_NAME)
+        GenericSetUp(self.s3_client, self.firehose_client)
 
     def tearDown(self) -> None:
-        GenericTearDown(s3_client, firehose_client)
+        GenericTearDown(self.s3_client, self.firehose_client)
 
     def test_get_error_message_for_ack_file(self):
        """Test the get_error_message_for_ack_file function."""
```
