
Commit dde853f

Merge branch 'feature/PI-702-prod_bulk_tweak' into release/2024-12-16

2 parents 19be248 + 2b89aef · commit dde853f

13 files changed: +182 additions, -84 deletions


infrastructure/terraform/per_workspace/modules/etl/sds/etl-diagram--bulk-transform-and-load.asl.json

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@
     },
     "Map": {
       "Type": "Map",
+      "MaxConcurrency": 10,
       "ItemProcessor": {
         "ProcessorConfig": {
           "Mode": "INLINE"

scripts/etl/clear_state_inputs.py

Lines changed: 4 additions & 7 deletions

@@ -10,8 +10,8 @@
 import boto3
 from etl_utils.constants import CHANGELOG_NUMBER, WorkerKey
 from etl_utils.io import pkl_dumps_lz4
-from sds.epr.bulk_create.bulk_load_fanout import FANOUT
 
+from etl.sds.tests.etl_test_utils.etl_state import _delete_objects_by_prefix
 from test_helpers.aws_session import aws_session
 from test_helpers.terraform import read_terraform_output
 
@@ -38,12 +38,9 @@ def main(changelog_number, workspace):
     s3_client.put_object(
         Bucket=etl_bucket, Key=WorkerKey.LOAD, Body=EMPTY_JSON_DATA
     )
-    for i in range(FANOUT):
-        s3_client.put_object(
-            Bucket=etl_bucket,
-            Key=f"{WorkerKey.LOAD}.{i}",
-            Body=pkl_dumps_lz4(EMPTY_JSON_DATA),
-        )
+    _delete_objects_by_prefix(
+        s3_client=s3_client, bucket=etl_bucket, key_prefix=f"{WorkerKey.LOAD}."
+    )
     s3_client.delete_object(Bucket=etl_bucket, Key=CHANGELOG_NUMBER)
 
     if changelog_number:

src/etl/sds/tests/etl_test_utils/etl_state.py

Lines changed: 28 additions & 11 deletions

@@ -10,7 +10,6 @@
 )
 from etl_utils.io import pkl_dumps_lz4
 from mypy_boto3_s3 import S3Client
-from sds.epr.bulk_create.bulk_load_fanout import FANOUT
 
 from test_helpers.terraform import read_terraform_output
 
@@ -40,10 +39,28 @@ def get_etl_config(input_filename: str, etl_type: str = "") -> EtlConfig:
     )
 
 
-def _delete_objects_by_prefix(s3_client: S3Client, bucket: str, key_prefix: str):
-    response = s3_client.list_objects(Bucket=bucket, Prefix=key_prefix)
-    for item in response.get("Contents", []):
-        s3_client.delete_object(Bucket=bucket, Key=item["Key"])
+def _delete_objects_by_prefix(
+    s3_client: S3Client, bucket: str, key_prefix: str, **kwargs
+):
+    # Delete objects if any found
+    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=key_prefix, **kwargs)
+    try:
+        contents = response["Contents"]
+    except KeyError:
+        return
+    else:
+        for item in contents:
+            s3_client.delete_object(Bucket=bucket, Key=item["Key"])
+
+    # Repeat if required
+    continuation_token = response.get("ContinuationToken")
+    if continuation_token:
+        return _delete_objects_by_prefix(
+            s3_client=s3_client,
+            bucket=bucket,
+            key_prefix=key_prefix,
+            ContinuationToken=continuation_token,
+        )
 
 
 def clear_etl_state(s3_client: S3Client, etl_config: EtlConfig):
@@ -61,12 +78,12 @@ def clear_etl_state(s3_client: S3Client, etl_config: EtlConfig):
         Key=WorkerKey.LOAD,
         Body=pkl_dumps_lz4(EMPTY_JSON_DATA),
     )
-    for i in range(FANOUT):
-        s3_client.put_object(
-            Bucket=etl_config.bucket,
-            Key=f"{WorkerKey.LOAD}.{i}",
-            Body=pkl_dumps_lz4(EMPTY_JSON_DATA),
-        )
+
+    # Delete load-fanout files, if they exist
+    _delete_objects_by_prefix(
+        s3_client=s3_client, bucket=etl_config.bucket, key_prefix=f"{WorkerKey.LOAD}."
+    )
+
     s3_client.delete_object(
         Bucket=etl_config.bucket, Key=etl_config.initial_trigger_key
    )
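
Note on the rewritten helper (an editorial aside, not part of the commit): as I read the S3 ListObjectsV2 API, the next-page token comes back as NextContinuationToken, while ContinuationToken in the response only echoes the request, so the recursive "Repeat if required" branch above is unlikely to ever fire on truncated listings. A minimal sketch of the same cleanup using boto3's built-in paginator, which tracks the token internally (the function name here is hypothetical):

    from mypy_boto3_s3 import S3Client


    def _delete_objects_by_prefix_paginated(
        s3_client: S3Client, bucket: str, key_prefix: str
    ):
        # boto3's list_objects_v2 paginator follows NextContinuationToken for us,
        # so no manual token handling or recursion is needed.
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=key_prefix):
            for item in page.get("Contents", []):
                s3_client.delete_object(Bucket=bucket, Key=item["Key"])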

src/etl/sds/worker/bulk/load_bulk/tests/test_load_bulk_worker.py

Lines changed: 11 additions & 10 deletions

@@ -24,7 +24,6 @@
 from event.json import json_load
 from moto import mock_aws
 from mypy_boto3_s3 import S3Client
-from sds.epr.bulk_create.bulk_load_fanout import FANOUT
 from sds.epr.constants import AS_DEVICE_SUFFIX, MHS_DEVICE_SUFFIX
 
 from etl.sds.worker.bulk.tests.test_bulk_e2e import PATH_TO_STAGE_DATA
@@ -73,31 +72,33 @@ def test_load_worker_pass(
     from etl.sds.worker.bulk.load_bulk import load_bulk
 
     # Initial state
-    for i in range(FANOUT):
-        with open(PATH_TO_STAGE_DATA / f"3.load_fanout_output.{i}.json") as f:
+    input_paths = sorted(PATH_TO_STAGE_DATA.glob(f"3.load_fanout_output.*.json"))
+    load_fanout_keys = []
+    for i, path in enumerate(input_paths):
+        with open(path) as f:
             input_data = json_load(f)
 
+        fanout_key = f"{WorkerKey.LOAD}.{i}"
         put_object(
-            key=f"{WorkerKey.LOAD}.{i}",
+            key=fanout_key,
            body=pkl_dumps_lz4(deque(input_data)),
        )
+        load_fanout_keys.append(fanout_key)
 
     # Execute the load worker
     with mock_table(TABLE_NAME) as dynamodb_client:
         load_bulk.CACHE.REPOSITORY.client = dynamodb_client
 
         responses = []
-        for i in range(FANOUT):
+        for fanout_key in load_fanout_keys:
             response = load_bulk.handler(
-                event={"s3_input_path": f"s3://{BUCKET_NAME}/{WorkerKey.LOAD}.{i}"},
+                event={"s3_input_path": f"s3://{BUCKET_NAME}/{fanout_key}"},
                 context=None,
             )
             responses.append(response)
 
             # Final state
-            final_unprocessed_data = pkl_loads_lz4(
-                get_object(key=f"{WorkerKey.LOAD}.{i}")
-            )
+            final_unprocessed_data = pkl_loads_lz4(get_object(key=fanout_key))
             assert final_unprocessed_data == deque([])
 
     assert responses == [
@@ -107,7 +108,7 @@ def test_load_worker_pass(
             "unprocessed_records": 0,
             "error_message": None,
         }
-    ] * (FANOUT - 1) + [
+    ] * (len(load_fanout_keys) - 1) + [
        {
            "stage_name": "load",
            "processed_records": 7,

src/etl/sds/worker/bulk/load_bulk_fanout/load_bulk_fanout.py

Lines changed: 2 additions & 5 deletions

@@ -22,7 +22,6 @@
 )
 from event.environment import BaseEnvironment
 from event.step_chain import StepChain
-from sds.epr.bulk_create.bulk_load_fanout import FANOUT, calculate_batch_size
 from sds.epr.bulk_create.bulk_repository import BulkRepository
 
 if TYPE_CHECKING:
@@ -38,6 +37,7 @@ def s3_path(self, key) -> str:
 
 ENVIRONMENT = TransformWorkerEnvironment.build()
 S3_CLIENT = boto3.client("s3")
+EACH_FANOUT_BATCH_SIZE = 10000
 
 
 def execute_step_chain(
@@ -72,11 +72,8 @@ def execute_step_chain(
 
     count_unprocessed_records = len(action_chain.result.unprocessed_records)
 
-    batch_size = calculate_batch_size(
-        sequence=action_chain.result.processed_records, n_batches=FANOUT
-    )
     for i, batch in enumerate(
-        batched(action_chain.result.processed_records, n=batch_size)
+        batched(action_chain.result.processed_records, n=EACH_FANOUT_BATCH_SIZE)
     ):
         count_processed_records = len(batch)
         _action_response = WorkerActionResponse(
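
The worker now cuts fixed-size batches of EACH_FANOUT_BATCH_SIZE records instead of deriving a per-run batch size from FANOUT, so the number of fan-out files scales with the input volume and only the final batch can be short. A standalone illustration of the itertools.batched behaviour this relies on (Python 3.12+; the record values are stand-ins):

    from itertools import batched  # Python 3.12+

    EACH_FANOUT_BATCH_SIZE = 10000
    processed_records = range(25000)  # stand-in for the worker's processed records

    batch_sizes = [
        len(batch) for batch in batched(processed_records, EACH_FANOUT_BATCH_SIZE)
    ]
    assert batch_sizes == [10000, 10000, 5000]  # every batch full except the last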

src/etl/sds/worker/bulk/load_bulk_fanout/tests/test_load_bulk_fanout_worker.py

Lines changed: 29 additions & 39 deletions

@@ -1,6 +1,5 @@
 import os
 from collections import deque
-from itertools import batched
 from pathlib import Path
 from typing import Callable
 from unittest import mock
@@ -17,11 +16,7 @@
 from event.json import json_load
 from moto import mock_aws
 from mypy_boto3_s3 import S3Client
-from sds.epr.bulk_create.bulk_load_fanout import (
-    FANOUT,
-    calculate_batch_size,
-    count_indexes,
-)
+from sds.epr.bulk_create.bulk_load_fanout import count_indexes
 
 from etl.sds.worker.bulk.tests.test_bulk_e2e import PATH_TO_STAGE_DATA
 
@@ -78,30 +73,15 @@ def decompress(obj: dict) -> dict:
     return obj
 
 
-@pytest.mark.parametrize(
-    ("n_batches", "sequence_length", "expected_batch_size", "expected_n_batches"),
-    ((4, 100, 25, 4), (3, 100, 34, 3), (16, 30, 2, 15)),
-)
-def test_calculate_batch_size_general(
-    n_batches: int,
-    sequence_length: int,
-    expected_batch_size: int,
-    expected_n_batches: int,
-):
-    n_batches = n_batches
-    sequence = list(range(sequence_length))
-    batch_size = calculate_batch_size(sequence, n_batches)
-    assert batch_size == expected_batch_size
-
-    batches = list(batched(sequence, batch_size))
-    assert len(batches) == expected_n_batches
-
-
 def test_load_worker_fanout(
-    put_object: Callable[[str], None], get_object: Callable[[str], bytes]
+    put_object: Callable[[str], None],
+    get_object: Callable[[str], bytes],
 ):
+    _EACH_FANOUT_BATCH_SIZE = 10
     from etl.sds.worker.bulk.load_bulk_fanout import load_bulk_fanout
 
+    load_bulk_fanout.EACH_FANOUT_BATCH_SIZE = _EACH_FANOUT_BATCH_SIZE
+
     # Initial state
     with open(PATH_TO_STAGE_DATA / "2.transform_output.json") as f:
         input_data: list[dict[str, dict]] = json_load(f)
@@ -114,31 +94,36 @@ def test_load_worker_fanout(
     # Execute the load worker
     responses = load_bulk_fanout.handler(event={}, context=None)
 
-    assert len(responses) == FANOUT
-    assert responses == [
+    *head_responses, tail_response = responses
+
+    assert len(head_responses) > 1
+
+    expected_head_responses = [
         {
             "stage_name": "load_bulk_fanout",
-            "processed_records": 10,
+            "processed_records": _EACH_FANOUT_BATCH_SIZE,
             "unprocessed_records": 0,
             "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{i}",
             "error_message": None,
         }
-        for i in range(0, FANOUT - 1)
-    ] + [
-        {
-            "stage_name": "load_bulk_fanout",
-            "processed_records": 7,
-            "unprocessed_records": 0,
-            "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{FANOUT - 1}",
-            "error_message": None,
-        },
+        for i in range(len(head_responses))
     ]
+    assert head_responses == expected_head_responses
+
+    tail_processed_records = tail_response.pop("processed_records")
+    assert tail_processed_records <= _EACH_FANOUT_BATCH_SIZE
+    assert tail_response == {
+        "stage_name": "load_bulk_fanout",
+        "unprocessed_records": 0,
+        "s3_input_path": f"s3://my-bucket/input--load/unprocessed.{len(head_responses)}",
+        "error_message": None,
+    }
 
     # Final state
     final_processed_data = pkl_loads_lz4(get_object(key=WorkerKey.LOAD))
     assert final_processed_data == deque([])
     total_size = 0
-    for i in range(10):
+    for i in range(_EACH_FANOUT_BATCH_SIZE):
         final_unprocessed_data = pkl_loads_lz4(get_object(key=f"{WorkerKey.LOAD}.{i}"))
         assert isinstance(final_unprocessed_data, deque)
         total_size += len(final_unprocessed_data)
@@ -164,3 +149,8 @@ def test_load_worker_fanout(
         expected_total_size += count_indexes(obj)
 
     assert total_size == expected_total_size
+
+    total_processed_records_from_response = (
+        tail_processed_records + _EACH_FANOUT_BATCH_SIZE * len(head_responses)
+    )
+    assert total_size == total_processed_records_from_response
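
Two review-style observations, neither part of the commit: the final-state loop iterates range(_EACH_FANOUT_BATCH_SIZE), which appears to equal the number of fan-out files only because the fixture happens to produce exactly that many batches (range(len(responses)) would track it directly); and assigning load_bulk_fanout.EACH_FANOUT_BATCH_SIZE in the test body leaves the module patched for the rest of the session. A minimal sketch of the latter concern handled with pytest's monkeypatch fixture:

    import pytest


    def test_fanout_batch_size_is_patched(monkeypatch: pytest.MonkeyPatch):
        from etl.sds.worker.bulk.load_bulk_fanout import load_bulk_fanout

        # monkeypatch restores the original module attribute when the test ends,
        # so the smaller batch size cannot leak into other tests in the same run.
        monkeypatch.setattr(load_bulk_fanout, "EACH_FANOUT_BATCH_SIZE", 10)
        assert load_bulk_fanout.EACH_FANOUT_BATCH_SIZE == 10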

src/layers/domain/core/validation.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ class AccreditedSystem:
     ID_PATTERN = re.compile(rf"^[a-zA-Z-0-9]+$")
 
 class PartyKey:
-    PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{6,9}}$"
+    PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{5,9}}$"
     ID_PATTERN = re.compile(PARTY_KEY_REGEX)
 
 class CpaId:
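
The loosened quantifier means party keys with a five-digit numeric suffix now validate. A quick check of the new pattern, using a simple alphanumeric stand-in for _ODS_CODE_REGEX (the real pattern is defined earlier in validation.py):

    import re

    _ODS_CODE_REGEX = r"[A-Za-z0-9]+"  # stand-in for the project's ODS-code pattern
    PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{5,9}}$"

    assert re.match(PARTY_KEY_REGEX, "ABC123-12345")  # 5-digit suffix: now accepted
    assert re.match(PARTY_KEY_REGEX, "ABC123-123456789")  # 9 digits: still accepted
    assert not re.match(PARTY_KEY_REGEX, "ABC123-1234")  # 4 digits: still rejected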

src/layers/sds/epr/bulk_create/bulk_create.py

Lines changed: 19 additions & 1 deletion

@@ -1,3 +1,5 @@
+from typing import Literal
+
 from domain.core.aggregate_root import AggregateRoot
 from domain.core.product_team.v1 import ProductTeam
 from domain.core.questionnaire import Questionnaire, QuestionnaireResponse
@@ -91,6 +93,22 @@ def _create_complete_epr_product(
     ]
 
 
+def _impute_manufacturer_org(
+    item: dict[Literal["nhs_mhs_manufacturer_org", "nhs_id_code"], str]
+):
+    """
+    Impute nhs_mhs_manufacturer_org if it is clearly invalid (not alphanumeric)
+    or does not exist by replacing it with the nhs_id_code
+    """
+    manufacturer_org = item["nhs_mhs_manufacturer_org"]
+    item["nhs_mhs_manufacturer_org"] = (
+        manufacturer_org
+        if manufacturer_org is not None and manufacturer_org.isalnum()
+        else item["nhs_id_code"]
+    )
+    return item
+
+
 def create_complete_epr_product(
     party_key_group: list[dict],
     mhs_device_questionnaire: Questionnaire,
@@ -110,6 +128,7 @@ def create_complete_epr_product(
     message_handling_systems: list[dict] = []
     accredited_systems: list[dict] = []
     for item in party_key_group:
+        item = _impute_manufacturer_org(item)
         if item["object_class"].lower() == NhsMhs.OBJECT_CLASS:
             message_handling_systems.append(item)
         else:
@@ -127,7 +146,6 @@ def create_complete_epr_product(
         mhs_device_questionnaire=mhs_device_questionnaire,
         mhs_device_field_mapping=mhs_device_field_mapping,
     )
-
     ods_code = first_mhs["nhs_mhs_manufacturer_org"]
     party_key = first_mhs["nhs_mhs_party_key"]
     product_name = first_mhs["nhs_product_name"] or first_mhs["nhs_mhs_cpa_id"]
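
For reference, the behaviour of the new helper on the three interesting cases, with illustrative field values:

    from sds.epr.bulk_create.bulk_create import _impute_manufacturer_org

    # Illustrative values only; the helper mutates and returns the same dict.
    record = {"nhs_mhs_manufacturer_org": "XYZ789", "nhs_id_code": "ABC123"}
    assert _impute_manufacturer_org(record)["nhs_mhs_manufacturer_org"] == "XYZ789"

    record = {"nhs_mhs_manufacturer_org": "not alphanumeric!", "nhs_id_code": "ABC123"}
    assert _impute_manufacturer_org(record)["nhs_mhs_manufacturer_org"] == "ABC123"

    record = {"nhs_mhs_manufacturer_org": None, "nhs_id_code": "ABC123"}
    assert _impute_manufacturer_org(record)["nhs_mhs_manufacturer_org"] == "ABC123"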

src/layers/sds/epr/bulk_create/bulk_load_fanout.py

Lines changed: 0 additions & 8 deletions

@@ -1,12 +1,8 @@
-from math import ceil
-
 from domain.core.cpm_product.v1 import CpmProduct
 from domain.core.device.v1 import Device
 from domain.core.device_reference_data.v1 import DeviceReferenceData
 from domain.core.product_team.v1 import ProductTeam
 
-FANOUT = 10
-
 
 def count_indexes(obj: Device | DeviceReferenceData | CpmProduct | ProductTeam):
     count = 1
@@ -15,7 +11,3 @@ def count_indexes(obj: Device | DeviceReferenceData | CpmProduct | ProductTeam):
     if isinstance(obj, (Device)):
         count += len(obj.tags)
     return count
-
-
-def calculate_batch_size(sequence: list, n_batches: int) -> int:
-    return ceil(len(sequence) / (n_batches)) or 1

src/layers/sds/epr/bulk_create/creators.py

Lines changed: 12 additions & 1 deletion

@@ -7,7 +7,11 @@
 from domain.core.product_team_key import ProductTeamKey, ProductTeamKeyType
 from domain.core.questionnaire import QuestionnaireResponse
 from domain.core.root import Root
-from sds.epr.constants import EprNameTemplate, SdsDeviceReferenceDataPath
+from sds.epr.constants import (
+    EXCEPTIONAL_ODS_CODES,
+    EprNameTemplate,
+    SdsDeviceReferenceDataPath,
+)
 
 
 def create_epr_product_team(ods_code: str) -> ProductTeam:
@@ -16,6 +20,13 @@ def create_epr_product_team(ods_code: str) -> ProductTeam:
         key_value=EprNameTemplate.PRODUCT_TEAM_KEY.format(ods_code=ods_code),
     )
 
+    if ods_code in EXCEPTIONAL_ODS_CODES:
+        return ProductTeam(
+            name=EprNameTemplate.PRODUCT_TEAM.format(ods_code=ods_code),
+            ods_code=ods_code,
+            keys=[product_team_key],
+        )
+
     org = Root.create_ods_organisation(ods_code=ods_code)
     return org.create_product_team(
         name=EprNameTemplate.PRODUCT_TEAM.format(ods_code=ods_code),
