Commit f60f7fa

Commit message: init
1 parent: d569421

11 files changed (+48 additions, -38 deletions)
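
In substance: the module-level s3_client that common/clients.py created eagerly at import is replaced by a lazily initialised get_s3_client() factory, and every Lambda call site and affected test is updated to match.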

lambdas/filenameprocessor/src/file_name_processor.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 from uuid import uuid4

 from audit_table import upsert_audit_table
-from common.clients import STREAM_NAME, logger, s3_client
+from common.clients import STREAM_NAME, logger, get_s3_client
 from common.log_decorator import logging_decorator
 from common.models.errors import (
     InvalidFileKeyError,
@@ -73,7 +73,7 @@ def handle_record(record) -> dict:

     try:
         message_id = str(uuid4())
-        s3_response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
+        s3_response = get_s3_client().get_object(Bucket=bucket_name, Key=file_key)
         created_at_formatted_string, expiry_timestamp = get_creation_and_expiry_times(s3_response)

         vaccine_type, supplier = validate_file_key(file_key)

lambdas/filenameprocessor/src/make_and_upload_ack_file.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 from csv import writer
 from io import BytesIO, StringIO

-from common.clients import s3_client
+from common.clients import get_s3_client


 def make_the_ack_data(message_id: str, message_delivered: bool, created_at_formatted_string: str) -> dict:
@@ -43,7 +43,7 @@ def upload_ack_file(file_key: str, ack_data: dict, created_at_formatted_string:
     csv_buffer.seek(0)
     csv_bytes = BytesIO(csv_buffer.getvalue().encode("utf-8"))
     ack_bucket_name = os.getenv("ACK_BUCKET_NAME")
-    s3_client.upload_fileobj(csv_bytes, ack_bucket_name, ack_filename)
+    get_s3_client().upload_fileobj(csv_bytes, ack_bucket_name, ack_filename)


 def make_and_upload_the_ack_file(

lambdas/filenameprocessor/src/utils_for_filenameprocessor.py

Lines changed: 2 additions & 1 deletion

@@ -2,7 +2,7 @@

 from datetime import timedelta

-from common.clients import logger, s3_client
+from common.clients import logger, get_s3_client
 from constants import AUDIT_TABLE_TTL_DAYS


@@ -16,6 +16,7 @@ def get_creation_and_expiry_times(s3_response: dict) -> (str, int):

 def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None:
     """Moves a file from one location to another within a single S3 bucket by copying and then deleting the file."""
+    s3_client = get_s3_client()
     s3_client.copy_object(
         Bucket=bucket_name,
         CopySource={"Bucket": bucket_name, "Key": source_file_key},

lambdas/recordprocessor/src/file_level_validation.py

Lines changed: 4 additions & 3 deletions

@@ -6,7 +6,7 @@
 from csv import DictReader

 from audit_table import update_audit_table_status
-from common.clients import logger, s3_client
+from common.clients import logger, get_s3_client
 from constants import (
     ARCHIVE_DIR_NAME,
     EXPECTED_CSV_HEADERS,
@@ -63,12 +63,13 @@ def get_permitted_operations(supplier: str, vaccine_type: str, allowed_permissio

 def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None:
     """Moves a file from one location to another within a single S3 bucket by copying and then deleting the file."""
-    s3_client.copy_object(
+    s3_client = get_s3_client()
+    get_s3_client().copy_object(
         Bucket=bucket_name,
         CopySource={"Bucket": bucket_name, "Key": source_file_key},
         Key=destination_file_key,
     )
-    s3_client.delete_object(Bucket=bucket_name, Key=source_file_key)
+    get_s3_client().delete_object(Bucket=bucket_name, Key=source_file_key)
     logger.info("File moved from %s to %s", source_file_key, destination_file_key)

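One wrinkle in the hunk above: the new move_file binds s3_client = get_s3_client() but then calls the factory again for both operations, so the local binding goes unused. The filenameprocessor version of move_file earlier in this commit binds once and reuses it; assuming get_s3_client() caches its client (its definition is not shown in this diff), the two shapes behave identically, but the single binding reads more clearly and avoids an unused-variable lint warning:

def move_file(bucket_name: str, source_file_key: str, destination_file_key: str) -> None:
    """Moves a file within a single S3 bucket by copying and then deleting it."""
    s3_client = get_s3_client()  # fetch once; both calls reuse the same client
    s3_client.copy_object(
        Bucket=bucket_name,
        CopySource={"Bucket": bucket_name, "Key": source_file_key},
        Key=destination_file_key,
    )
    s3_client.delete_object(Bucket=bucket_name, Key=source_file_key)
    logger.info("File moved from %s to %s", source_file_key, destination_file_key)
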
lambdas/recordprocessor/src/make_and_upload_ack_file.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from csv import writer
 from io import BytesIO, StringIO

-from common.clients import s3_client
+from common.clients import get_s3_client
 from constants import ACK_BUCKET_NAME


@@ -46,7 +46,7 @@ def upload_ack_file(file_key: str, ack_data: dict, created_at_formatted_string:
     # Upload the CSV file to S3
     csv_buffer.seek(0)
     csv_bytes = BytesIO(csv_buffer.getvalue().encode("utf-8"))
-    s3_client.upload_fileobj(csv_bytes, ACK_BUCKET_NAME, ack_filename)
+    get_s3_client().upload_fileobj(csv_bytes, ACK_BUCKET_NAME, ack_filename)


 def make_and_upload_ack_file(

lambdas/recordprocessor/src/utils_for_recordprocessor.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 from csv import DictReader
 from io import TextIOWrapper

-from common.clients import s3_client
+from common.clients import get_s3_client


 def get_environment() -> str:
@@ -16,7 +16,7 @@ def get_environment() -> str:

 def get_csv_content_dict_reader(file_key: str, encoder="utf-8") -> DictReader:
     """Returns the requested file contents from the source bucket in the form of a DictReader"""
-    response = s3_client.get_object(Bucket=os.getenv("SOURCE_BUCKET_NAME"), Key=file_key)
+    response = get_s3_client().get_object(Bucket=os.getenv("SOURCE_BUCKET_NAME"), Key=file_key)
     binary_io = response["Body"]
     text_io = TextIOWrapper(binary_io, encoding=encoder, newline="")
     return DictReader(text_io, delimiter="|")
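
For context, get_csv_content_dict_reader streams the pipe-delimited CSV straight out of the S3 response body rather than loading it into memory. A usage sketch (the file key here is made up for illustration; SOURCE_BUCKET_NAME is assumed to be set in the environment):

reader = get_csv_content_dict_reader("some_batch_file.csv")
for row in reader:
    # each row is a dict keyed by the pipe-delimited header fields
    print(row)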

lambdas/recordprocessor/tests/test_recordprocessor_edge_cases.py

Lines changed: 19 additions & 10 deletions

@@ -1,10 +1,11 @@
 import os
 import unittest
 from io import BytesIO
-from unittest.mock import call, patch
+from unittest.mock import Mock, call, patch

 from batch_processor import process_csv_to_fhir
 from utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import (
+    BucketNames,
     create_patch,
 )

@@ -16,8 +17,7 @@ def setUp(self):
         self.mock_logger_error = create_patch("logging.Logger.error")
         self.mock_send_to_kinesis = create_patch("batch_processor.send_to_kinesis")
         self.mock_map_target_disease = create_patch("batch_processor.map_target_disease")
-        self.mock_s3_get_object = create_patch("utils_for_recordprocessor.s3_client.get_object")
-        self.mock_s3_put_object = create_patch("utils_for_recordprocessor.s3_client.put_object")
+        self.mock_get_s3_client = create_patch("utils_for_recordprocessor.get_s3_client")
         self.mock_make_and_move = create_patch("file_level_validation.make_and_upload_ack_file")
         self.mock_move_file = create_patch("file_level_validation.move_file")
         self.mock_get_permitted_operations = create_patch("file_level_validation.get_permitted_operations")
@@ -63,7 +63,9 @@ def test_process_large_file_cp1252(self):
         data = self.insert_cp1252_at_end(data, b"D\xe9cembre", 2)
         ret1 = {"Body": BytesIO(b"".join(data))}
         ret2 = {"Body": BytesIO(b"".join(data))}
-        self.mock_s3_get_object.side_effect = [ret1, ret2]
+        mock_s3 = Mock()
+        mock_s3.get_object.side_effect = [ret1, ret2]
+        self.mock_get_s3_client.return_value = mock_s3
         self.mock_map_target_disease.return_value = "some disease"

         message_body = {
@@ -80,10 +82,11 @@ def test_process_large_file_cp1252(self):
         self.mock_logger_warning.assert_called()
         warning_call_args = self.mock_logger_warning.call_args[0][0]
         self.assertTrue(warning_call_args.startswith("Encoding Error: 'utf-8' codec can't decode byte 0xe9"))
-        self.mock_s3_get_object.assert_has_calls(
+        # TODO: when running all tests this expects Bucket=None. not clear why.
+        mock_s3.get_object.assert_has_calls(
             [
-                call(Bucket=None, Key="test-filename"),
-                call(Bucket=None, Key="processing/test-filename"),
+                call(Bucket=BucketNames.SOURCE, Key="test-filename"),
+                call(Bucket=BucketNames.SOURCE, Key="processing/test-filename"),
             ]
         )

@@ -94,7 +97,9 @@ def test_process_large_file_utf8(self):
         data = self.expand_test_data(data, n_rows)
         ret1 = {"Body": BytesIO(b"".join(data))}
         ret2 = {"Body": BytesIO(b"".join(data))}
-        self.mock_s3_get_object.side_effect = [ret1, ret2]
+        mock_s3 = Mock()
+        mock_s3.get_object.side_effect = [ret1, ret2]
+        self.mock_get_s3_client.return_value = mock_s3
         self.mock_map_target_disease.return_value = "some disease"

         message_body = {
@@ -118,7 +123,9 @@ def test_process_small_file_cp1252(self):

         ret1 = {"Body": BytesIO(b"".join(data))}
         ret2 = {"Body": BytesIO(b"".join(data))}
-        self.mock_s3_get_object.side_effect = [ret1, ret2]
+        mock_s3 = Mock()
+        mock_s3.get_object.side_effect = [ret1, ret2]
+        self.mock_get_s3_client.return_value = mock_s3
         self.mock_map_target_disease.return_value = "some disease"

         message_body = {
@@ -143,7 +150,9 @@ def test_process_small_file_utf8(self):

         ret1 = {"Body": BytesIO(b"".join(data))}
         ret2 = {"Body": BytesIO(b"".join(data))}
-        self.mock_s3_get_object.side_effect = [ret1, ret2]
+        mock_s3 = Mock()
+        mock_s3.get_object.side_effect = [ret1, ret2]
+        self.mock_get_s3_client.return_value = mock_s3
         self.mock_map_target_disease.return_value = "some disease"

         message_body = {
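
Because the code under test now fetches its client through the factory, these tests stub one seam instead of two: create_patch replaces get_s3_client where utils_for_recordprocessor looks it up, and a plain Mock stands in for the whole client, making the old separate get_object/put_object patches unnecessary. A minimal self-contained sketch of the same pattern using unittest.mock directly (it assumes the record processor's sources are importable, as they are in these tests):

from io import BytesIO
from unittest.mock import Mock, patch

# Replace the factory at its point of lookup; one fake client then serves
# every S3 method the code under test touches.
with patch("utils_for_recordprocessor.get_s3_client") as mock_get_s3_client:
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {"Body": BytesIO(b"HEADER_A|HEADER_B\n1|2\n")}
    mock_get_s3_client.return_value = mock_s3
    # ... invoke the code under test here, then assert on the fake client,
    # e.g. mock_s3.get_object.assert_called_once()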

lambdas/shared/src/common/clients.py

Lines changed: 0 additions & 3 deletions

@@ -14,9 +14,6 @@

 REGION_NAME = os.getenv("AWS_REGION", "eu-west-2")

-s3_client = boto3_client("s3", region_name=REGION_NAME)
-
-# for lambdas which require a global s3_client
 global_s3_client = None
lambdas/shared/src/common/s3_reader.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from common.clients import logger, s3_client
+from common.clients import logger, get_s3_client


 class S3Reader:
@@ -12,7 +12,7 @@ class S3Reader:
     @staticmethod
     def read(bucket_name, file_key):
         try:
-            s3_file = s3_client.get_object(Bucket=bucket_name, Key=file_key)
+            s3_file = get_s3_client().get_object(Bucket=bucket_name, Key=file_key)
             return s3_file["Body"].read().decode("utf-8")

         except Exception as error:  # pylint: disable=broad-except

lambdas/shared/tests/test_common/test_clients.py

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,8 @@ def test_env_variables_loaded(self):

     def test_boto3_client_created_for_s3(self):
         """Test that S3 boto3 client is created with correct region"""
+        importlib.reload(clients)
+        clients.get_s3_client()
         self.mock_boto3_client.assert_any_call("s3", region_name=self.AWS_REGION)

     def test_boto3_client_created_for_firehose(self):
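
The two added lines compensate for the client no longer being built at import: importlib.reload(clients) re-imports the module under the test's boto3 patch, and the explicit clients.get_s3_client() call is what now triggers the boto3_client("s3", ...) invocation that the assertion checks. (This assumes importlib and the clients import already appear at the top of the test module, which the visible hunk does not show.)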
