
Commit b0267fa: "tests pass"

1 parent fc22370 commit b0267fa

10 files changed: +141 additions, -95 deletions
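In summary, this commit makes the ack backend testable under mocked AWS: the module-level S3 client becomes a lazily initialised get_s3_client(), the SOURCE_BUCKET_NAME and ACK_BUCKET_NAME constants become call-time getters, and the test files construct their own per-test boto3 clients so that moto mocks and patched environment variables are picked up.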

ack_backend/src/clients.py

Lines changed: 6 additions & 1 deletion
@@ -5,13 +5,18 @@
 
 REGION_NAME = "eu-west-2"
 
-s3_client = boto3_client("s3", region_name=REGION_NAME)
 firehose_client = boto3_client("firehose", region_name=REGION_NAME)
 lambda_client = boto3_client('lambda', region_name=REGION_NAME)
 dynamodb_client = boto3_client("dynamodb", region_name=REGION_NAME)
 
 dynamodb_resource = boto3_resource("dynamodb", region_name=REGION_NAME)
 
+s3_client = None
+def get_s3_client():
+    global s3_client
+    if s3_client is None:
+        s3_client = boto3_client("s3", region_name=REGION_NAME)
+    return s3_client
 
 # Logger
 logging.basicConfig(level="INFO")
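The lazy getter defers client construction until first use, so a client requested inside a moto context is created against the mocked backend instead of being bound at import time. A minimal sketch of that effect (the test function and bucket name are illustrative, not part of this commit):

    from moto import mock_s3
    from clients import get_s3_client

    @mock_s3
    def demo_lazy_client_sees_mock():
        s3 = get_s3_client()  # first call constructs the client inside the mock
        s3.create_bucket(
            Bucket="demo-bucket",
            CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
        )
        assert "demo-bucket" in [b["Name"] for b in s3.list_buckets()["Buckets"]]

The trade-off is that the client is cached in a module global and persists across calls, which is why the test files below stop sharing a module-level client and instead build a fresh one per test.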

ack_backend/src/constants.py

Lines changed: 8 additions & 3 deletions
@@ -2,13 +2,18 @@
 
 import os
 
-SOURCE_BUCKET_NAME = os.getenv("SOURCE_BUCKET_NAME")
-ACK_BUCKET_NAME = os.getenv("ACK_BUCKET_NAME")
 AUDIT_TABLE_NAME = os.getenv("AUDIT_TABLE_NAME")
+FILE_NAME_PROC_LAMBDA_NAME = os.getenv("FILE_NAME_PROC_LAMBDA_NAME")
 AUDIT_TABLE_FILENAME_GSI = "filename_index"
 AUDIT_TABLE_QUEUE_NAME_GSI = "queue_name_index"
-FILE_NAME_PROC_LAMBDA_NAME = os.getenv("FILE_NAME_PROC_LAMBDA_NAME")
 
+def get_source_bucket_name() -> str:
+    """Get the SOURCE_BUCKET_NAME value from environment variables."""
+    return os.getenv("SOURCE_BUCKET_NAME")
+
+def get_ack_bucket_name() -> str:
+    """Get the ACK_BUCKET_NAME value from environment variables."""
+    return os.getenv("ACK_BUCKET_NAME")
 
 class FileStatus:
     """File status constants"""

ack_backend/src/update_ack_file.py

Lines changed: 23 additions & 10 deletions
@@ -4,9 +4,9 @@
 from io import StringIO, BytesIO
 from typing import Union
 from botocore.exceptions import ClientError
-from constants import ACK_HEADERS, SOURCE_BUCKET_NAME, ACK_BUCKET_NAME, FILE_NAME_PROC_LAMBDA_NAME
+from constants import ACK_HEADERS, get_source_bucket_name, get_ack_bucket_name, FILE_NAME_PROC_LAMBDA_NAME
 from audit_table import change_audit_table_status_to_processed, get_next_queued_file_details
-from clients import s3_client, logger, lambda_client
+from clients import get_s3_client, logger, lambda_client
 from utils_for_ack_lambda import get_row_count
 
 
@@ -49,7 +49,7 @@ def obtain_current_ack_content(temp_ack_file_key: str) -> StringIO:
     """Returns the current ack file content if the file exists, or else initialises the content with the ack headers."""
     try:
         # If ack file exists in S3 download the contents
-        existing_ack_file = s3_client.get_object(Bucket=ACK_BUCKET_NAME, Key=temp_ack_file_key)
+        existing_ack_file = get_s3_client().get_object(Bucket=get_ack_bucket_name(), Key=temp_ack_file_key)
         existing_content = existing_ack_file["Body"].read().decode("utf-8")
     except ClientError as error:
         # If ack file does not exist in S3 create a new file containing the headers only
@@ -80,22 +80,26 @@ def upload_ack_file(
     cleaned_row = "|".join(data_row_str).replace(" |", "|").replace("| ", "|").strip()
     accumulated_csv_content.write(cleaned_row + "\n")
     csv_file_like_object = BytesIO(accumulated_csv_content.getvalue().encode("utf-8"))
-    s3_client.upload_fileobj(csv_file_like_object, ACK_BUCKET_NAME, temp_ack_file_key)
 
-    row_count_source = get_row_count(SOURCE_BUCKET_NAME, f"processing/{file_key}")
-    row_count_destination = get_row_count(ACK_BUCKET_NAME, temp_ack_file_key)
+    ack_bucket_name = get_ack_bucket_name()
+    source_bucket_name = get_source_bucket_name()
+
+    get_s3_client().upload_fileobj(csv_file_like_object, ack_bucket_name, temp_ack_file_key)
+
+    row_count_source = get_row_count(source_bucket_name, f"processing/{file_key}")
+    row_count_destination = get_row_count(ack_bucket_name, temp_ack_file_key)
     # TODO: Should we check for > and if so what handling is required
     if row_count_destination == row_count_source:
-        move_file(ACK_BUCKET_NAME, temp_ack_file_key, archive_ack_file_key)
-        move_file(SOURCE_BUCKET_NAME, f"processing/{file_key}", f"archive/{file_key}")
+        move_file(ack_bucket_name, temp_ack_file_key, archive_ack_file_key)
+        move_file(source_bucket_name, f"processing/{file_key}", f"archive/{file_key}")
 
     # Update the audit table and invoke the filename lambda with next file in the queue (if one exists)
     change_audit_table_status_to_processed(file_key, message_id)
     next_queued_file_details = get_next_queued_file_details(supplier_queue)
     if next_queued_file_details:
         invoke_filename_lambda(next_queued_file_details["filename"], next_queued_file_details["message_id"])
 
-    logger.info("Ack file updated to %s: %s", ACK_BUCKET_NAME, archive_ack_file_key)
+    logger.info("Ack file updated to %s: %s", ack_bucket_name, archive_ack_file_key)
 
 
 def update_ack_file(
@@ -123,6 +127,7 @@ def update_ack_file(
 
 def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None:
     """Moves a file from one location to another within a single S3 bucket by copying and then deleting the file."""
+    s3_client = get_s3_client()
     s3_client.copy_object(
         Bucket=bucket_name, CopySource={"Bucket": bucket_name, "Key": source_file_key}, Key=destination_file_key
     )
@@ -135,7 +140,15 @@ def invoke_filename_lambda(file_key: str, message_id: str) -> None:
     try:
         lambda_payload = {
             "Records": [
-                {"s3": {"bucket": {"name": SOURCE_BUCKET_NAME}, "object": {"key": file_key}}, "message_id": message_id}
+                {"s3":
+                    {
+                        "bucket": {
+                            "name": get_source_bucket_name()
+                        },
+                        "object": {"key": file_key}
+                    },
+                    "message_id": message_id
+                }
             ]
         }
         lambda_client.invoke(
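Two details worth noting in this file: the rebuilt payload keeps the shape of a real S3 event notification (Records[].s3.bucket.name and Records[].s3.object.key, with message_id carried alongside), and move_file works by copy-then-delete because S3 has no native rename. A hedged sketch of that sequence (bucket and keys are placeholders, not from the commit):

    s3_client = get_s3_client()
    s3_client.copy_object(
        Bucket="example-bucket",
        CopySource={"Bucket": "example-bucket", "Key": "processing/report.csv"},
        Key="archive/report.csv",
    )
    s3_client.delete_object(Bucket="example-bucket", Key="processing/report.csv")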
ack_backend/src/utils_for_ack_lambda.py

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 """Utils for ack lambda"""
 
-from clients import s3_client
+from clients import get_s3_client
 
 
 def get_row_count(bucket_name: str, file_key: str) -> int:
     """
     Looks in the given bucket and returns the count of the number of lines in the given file.
     NOTE: Blank lines are not included in the count.
     """
-    response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
+    response = get_s3_client().get_object(Bucket=bucket_name, Key=file_key)
     return sum(1 for line in response["Body"].iter_lines() if line.strip())
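get_row_count streams the object with iter_lines() rather than reading the whole body into memory, and the line.strip() filter is what excludes blank lines from the count. A self-contained sketch of the same logic under moto (bucket, key, and body are illustrative):

    from boto3 import client as boto3_client
    from moto import mock_s3

    @mock_s3
    def demo_row_count():
        s3 = boto3_client("s3", region_name="eu-west-2")
        s3.create_bucket(
            Bucket="demo-bucket",
            CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
        )
        s3.put_object(Bucket="demo-bucket", Key="demo.csv", Body="a\n\nb\n")
        response = s3.get_object(Bucket="demo-bucket", Key="demo.csv")
        # Two non-blank lines ("a" and "b"); the empty middle line is excluded
        assert sum(1 for line in response["Body"].iter_lines() if line.strip()) == 2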

ack_backend/tests/test_ack_processor.py

Lines changed: 11 additions & 12 deletions
@@ -24,9 +24,6 @@
 with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT):
     from ack_processor import lambda_handler
 
-s3_client = boto3_client("s3", region_name=REGION_NAME)
-firehose_client = boto3_client("firehose", region_name=REGION_NAME)
-
 BASE_SUCCESS_MESSAGE = MOCK_MESSAGE_DETAILS.success_message
 BASE_FAILURE_MESSAGE = {
     **{k: v for k, v in BASE_SUCCESS_MESSAGE.items() if k != "imms_id"},
@@ -41,12 +38,14 @@ class TestAckProcessor(unittest.TestCase):
     """Tests for the ack processor lambda handler."""
 
     def setUp(self) -> None:
-        GenericSetUp(s3_client, firehose_client)
+        self.s3_client = boto3_client("s3", region_name=REGION_NAME)
+        self.firehose_client = boto3_client("firehose", region_name=REGION_NAME)
+        GenericSetUp(self.s3_client, self.firehose_client)
 
         # MOCK SOURCE FILE WITH 100 ROWS TO SIMULATE THE SCENARIO WHERE THE ACK FILE IS NOT FULL.
         # TODO: Test all other scenarios.
         mock_source_file_with_100_rows = StringIO("\n".join(f"Row {i}" for i in range(1, 101)))
-        s3_client.put_object(
+        self.s3_client.put_object(
             Bucket=BucketNames.SOURCE,
             Key=f"processing/{MOCK_MESSAGE_DETAILS.file_key}",
             Body=mock_source_file_with_100_rows.getvalue(),
@@ -55,7 +54,7 @@ def setUp(self) -> None:
         self.mock_logger_info = self.logger_info_patcher.start()
 
     def tearDown(self) -> None:
-        GenericTearDown(s3_client, firehose_client)
+        GenericTearDown(self.s3_client, self.firehose_client)
         self.mock_logger_info.stop()
 
     @staticmethod
@@ -114,7 +113,7 @@ def test_lambda_handler_main_multiple_records(self):
         response = lambda_handler(event=event, context={})
 
         self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-        validate_ack_file_content(
+        validate_ack_file_content(self.s3_client,
             [*array_of_success_messages, *array_of_failure_messages, *array_of_mixed_success_and_failure_messages],
             existing_file_content=ValidValues.ack_headers,
         )
@@ -162,20 +161,20 @@ def test_lambda_handler_main(self):
             with self.subTest(msg=f"No existing ack file: {test_case['description']}"):
                 response = lambda_handler(event=self.generate_event(test_case["messages"]), context={})
                 self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-                validate_ack_file_content(test_case["messages"])
+                validate_ack_file_content(self.s3_client, test_case["messages"])
 
-            s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
+            self.s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
 
             # Test scenario where there is an existing ack file
             # TODO: None of the test cases have any existing ack file content?
             with self.subTest(msg=f"Existing ack file: {test_case['description']}"):
                 existing_ack_file_content = test_case.get("existing_ack_file_content", "")
-                setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_ack_file_content)
+                setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_ack_file_content, self.s3_client)
                 response = lambda_handler(event=self.generate_event(test_case["messages"]), context={})
                 self.assertEqual(response, EXPECTED_ACK_LAMBDA_RESPONSE_FOR_SUCCESS)
-                validate_ack_file_content(test_case["messages"], existing_ack_file_content)
+                validate_ack_file_content(self.s3_client, test_case["messages"], existing_ack_file_content)
 
-            s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
+            self.s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
 
     def test_lambda_handler_error_scenarios(self):
         """Test that the lambda handler raises appropriate exceptions for malformed event data."""

ack_backend/tests/test_convert_message_to_ack_row.py

Lines changed: 4 additions & 2 deletions
@@ -27,10 +27,12 @@ class TestAckProcessor(unittest.TestCase):
     """Tests for the ack processor lambda handler."""
 
     def setUp(self) -> None:
-        GenericSetUp(s3_client, firehose_client)
+        self.s3_client = boto3_client("s3", region_name=REGION_NAME)
+        self.firehose_client = boto3_client("firehose", region_name=REGION_NAME)
+        GenericSetUp(self.s3_client, self.firehose_client)
 
     def tearDown(self) -> None:
-        GenericTearDown(s3_client, firehose_client)
+        GenericTearDown(self.s3_client, self.firehose_client)
 
     def test_get_error_message_for_ack_file(self):
         """Test the get_error_message_for_ack_file function."""

ack_backend/tests/test_splunk_logging.py

Lines changed: 4 additions & 5 deletions
@@ -21,28 +21,27 @@
 with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT):
     from ack_processor import lambda_handler
 
-s3_client = boto3_client("s3")
-
 
 @patch.dict("os.environ", MOCK_ENVIRONMENT_DICT)
 @mock_s3
 class TestLoggingDecorators(unittest.TestCase):
     """Tests for the ack lambda logging decorators"""
 
     def setUp(self):
-        GenericSetUp(s3_client)
+        self.s3_client = boto3_client("s3", region_name="eu-west-2")
+        GenericSetUp(self.s3_client)
 
         # MOCK SOURCE FILE WITH 100 ROWS TO SIMULATE THE SCENARIO WHERE THE ACK FILE IS NOT FULL.
         # TODO: Test all other scenarios.
         mock_source_file_with_100_rows = StringIO("\n".join(f"Row {i}" for i in range(1, 101)))
-        s3_client.put_object(
+        self.s3_client.put_object(
             Bucket=BucketNames.SOURCE,
             Key=f"processing/{ValidValues.mock_message_expected_log_value.get('file_key')}",
             Body=mock_source_file_with_100_rows.getvalue(),
         )
 
     def tearDown(self):
-        GenericTearDown(s3_client)
+        GenericTearDown(self.s3_client)
 
     def run(self, result=None):
         """

ack_backend/tests/test_update_ack_file.py

Lines changed: 10 additions & 12 deletions
@@ -18,30 +18,28 @@
     generate_expected_ack_content,
     MOCK_MESSAGE_DETAILS,
 )
-from constants import ACK_HEADERS, SOURCE_BUCKET_NAME, ACK_BUCKET_NAME, FILE_NAME_PROC_LAMBDA_NAME
 
 from unittest.mock import patch
 from io import StringIO
 
 with patch.dict("os.environ", MOCK_ENVIRONMENT_DICT):
     from update_ack_file import obtain_current_ack_content, create_ack_data, update_ack_file
 
-s3_client = boto3_client("s3", region_name=REGION_NAME)
 firehose_client = boto3_client("firehose", region_name=REGION_NAME)
 
 
 @patch.dict(os.environ, MOCK_ENVIRONMENT_DICT)
 @mock_s3
 class TestUpdateAckFile(unittest.TestCase):
     """Tests for the functions in the update_ack_file module."""
-
     def setUp(self) -> None:
-        GenericSetUp(s3_client)
+        self.s3_client = boto3_client("s3", region_name=REGION_NAME)
+        GenericSetUp(self.s3_client)
 
         # MOCK SOURCE FILE WITH 100 ROWS TO SIMULATE THE SCENARIO WHERE THE ACK FILE IS NOT FULL.
         # TODO: Test all other scenarios.
         mock_source_file_with_100_rows = StringIO("\n".join(f"Row {i}" for i in range(1, 101)))
-        s3_client.put_object(
+        self.s3_client.put_object(
             Bucket=BucketNames.SOURCE,
             Key=f"processing/{MOCK_MESSAGE_DETAILS.file_key}",
             Body=mock_source_file_with_100_rows.getvalue(),
@@ -50,7 +48,7 @@ def setUp(self) -> None:
         self.mock_logger = self.logger_patcher.start()
 
     def tearDown(self) -> None:
-        GenericTearDown(s3_client)
+        GenericTearDown(self.s3_client)
 
     def validate_ack_file_content(
         self, incoming_messages: list[dict], existing_file_content: str = ValidValues.ack_headers
@@ -59,7 +57,7 @@ def validate_ack_file_content(
         Obtains the ack file content and ensures that it matches the expected content (expected content is based
         on the incoming messages).
        """
-        actual_ack_file_content = obtain_current_ack_file_content()
+        actual_ack_file_content = obtain_current_ack_file_content(self.s3_client)
         expected_ack_file_content = generate_expected_ack_content(incoming_messages, existing_file_content)
         self.assertEqual(expected_ack_file_content, actual_ack_file_content)
 
@@ -114,17 +112,17 @@ def test_update_ack_file(self):
                 ack_data_rows=test_case["input_rows"],
             )
 
-            actual_ack_file_content = obtain_current_ack_file_content()
+            actual_ack_file_content = obtain_current_ack_file_content(self.s3_client)
             expected_ack_file_content = ValidValues.ack_headers + "\n".join(test_case["expected_rows"]) + "\n"
             self.assertEqual(expected_ack_file_content, actual_ack_file_content)
 
-            s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
+            self.s3_client.delete_object(Bucket=BucketNames.DESTINATION, Key=MOCK_MESSAGE_DETAILS.temp_ack_file_key)
 
     def test_update_ack_file_existing(self):
         """Test that update_ack_file correctly updates the ack file when there was an existing ack file"""
         # Mock existing content in the ack file
         existing_content = generate_sample_existing_ack_content()
-        setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_content)
+        setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_content, self.s3_client)
 
         ack_data_rows = [ValidValues.ack_data_success_dict, ValidValues.ack_data_failure_dict]
         update_ack_file(
@@ -135,7 +133,7 @@ def test_update_ack_file_existing(self):
             ack_data_rows=ack_data_rows,
         )
 
-        actual_ack_file_content = obtain_current_ack_file_content()
+        actual_ack_file_content = obtain_current_ack_file_content(self.s3_client)
        expected_rows = [
            generate_expected_ack_file_row(success=True, imms_id=DefaultValues.imms_id),
            generate_expected_ack_file_row(success=False, imms_id="", diagnostics="DIAGNOSTICS"),
@@ -205,7 +203,7 @@ def test_obtain_current_ack_content_file_no_existing(self):
     def test_obtain_current_ack_content_file_exists(self):
         """Test that the existing ack file content is retrieved and new rows are added."""
         existing_content = generate_sample_existing_ack_content()
-        setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_content)
+        setup_existing_ack_file(MOCK_MESSAGE_DETAILS.temp_ack_file_key, existing_content, self.s3_client)
         result = obtain_current_ack_content(MOCK_MESSAGE_DETAILS.temp_ack_file_key)
         self.assertEqual(result.getvalue(), existing_content)