Skip to content

Commit ba1b60d

Browse files
[PRMP-202] Migration to Data Standards - LG table values (#793)
Co-authored-by: SWhyteAnswer <[email protected]> Co-authored-by: Sam Whyte <[email protected]>
1 parent 8afc7fc commit ba1b60d

File tree

5 files changed

+374
-24
lines changed

5 files changed

+374
-24
lines changed

lambdas/models/document_reference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,12 @@ def set_virus_scanner_result(self, updated_virus_scanner_result) -> None:
170170

171171
def set_uploaded_to_true(self):
172172
self.uploaded = True
173+
174+
def infer_doc_status(self) -> str | None:
175+
if self.deleted:
176+
return "deprecated"
177+
if self.uploaded:
178+
return "final"
179+
if self.uploading:
180+
return "preliminary"
181+
return None
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import argparse
2+
from typing import Iterable, Callable
3+
4+
from enums.snomed_codes import SnomedCodes
5+
from models.document_reference import DocumentReference
6+
from services.base.dynamo_service import DynamoDBService
7+
from utils.audit_logging_setup import LoggingService
8+
9+
10+
class VersionMigration:
11+
def __init__(self, environment: str, table_name: str, run_migration: bool = False):
12+
self.environment = environment
13+
self.table_name = table_name
14+
self.run_migration = run_migration
15+
self.logger = LoggingService("CustodianMigration")
16+
self.dynamo_service = DynamoDBService()
17+
18+
self.target_table = f"{self.environment}_{self.table_name}"
19+
20+
def main(
21+
self, entries: Iterable[dict]
22+
) -> list[tuple[str, Callable[[dict], dict | None]]]:
23+
"""
24+
Main entry point for the migration.
25+
Returns a list of (label, update function) tuples.
26+
Accepts a list of entries for Lambda-based execution, or scans the table if `entries` is None.
27+
"""
28+
self.logger.info("Starting version migration")
29+
self.logger.info(f"Target table: {self.target_table}")
30+
self.logger.info(f"Dry run mode: {not self.run_migration}")
31+
32+
if entries is None:
33+
self.logger.error("No entries provided after scanning entire table.")
34+
raise ValueError("Entries must be provided to main().")
35+
36+
return [
37+
("LGTableValues", self.get_updated_items)
38+
]
39+
40+
def process_entries(
41+
self,
42+
label: str,
43+
entries: Iterable[dict],
44+
update_fn: Callable[[dict], dict | None],
45+
):
46+
"""
47+
Processes a list of entries, applying the update function to each.
48+
Logs progress and handles dry-run mode.
49+
"""
50+
self.logger.info(f"Running {label} migration")
51+
52+
for index, entry in enumerate(entries, start=1):
53+
item_id = entry.get("ID")
54+
self.logger.info(
55+
f"[{label}] Processing item {index} (ID: {item_id})"
56+
)
57+
58+
updated_fields = update_fn(entry)
59+
if not updated_fields:
60+
self.logger.debug(
61+
f"[{label}] Item {item_id} does not require update, skipping."
62+
)
63+
continue
64+
65+
if self.run_migration:
66+
self.logger.info(f"Updating item {item_id} with {updated_fields}")
67+
try:
68+
self.dynamo_service.update_item(
69+
table_name=self.target_table,
70+
key_pair={"ID": item_id},
71+
updated_fields=updated_fields,
72+
)
73+
except Exception as e:
74+
self.logger.error(f"Failed to update item {item_id}: {str(e)}")
75+
continue
76+
else:
77+
self.logger.info(
78+
f"[Dry Run] Would update item {item_id} with {updated_fields}"
79+
)
80+
81+
self.logger.info(f"{label} migration completed.") # Moved outside the loop
82+
83+
def get_updated_items(self, entry: dict) -> dict | None:
84+
"""
85+
Aggregates updates from all update methods for a single entry.
86+
Returns a dict of fields to update, or None if no update is needed.
87+
"""
88+
update_items = {}
89+
90+
if custodian_update_items := self.get_update_custodian_items(entry):
91+
update_items.update(custodian_update_items)
92+
93+
if status_update_items := self.get_update_status_items(entry):
94+
update_items.update(status_update_items)
95+
96+
if snomed_code_update_items := self.get_update_document_snomed_code_type_items(entry):
97+
update_items.update(snomed_code_update_items)
98+
99+
if doc_status_update_items := self.get_update_doc_status_items(entry):
100+
update_items.update(doc_status_update_items)
101+
102+
if version_update_items := self.get_update_version_items(entry):
103+
update_items.update(version_update_items)
104+
105+
return update_items if update_items else None
106+
107+
def get_update_custodian_items(self, entry: dict) -> dict | None:
108+
"""
109+
Updates the 'Custodian' field if it does not match 'CurrentGpOds'.
110+
Returns a dict with the update or None.
111+
"""
112+
current_gp_ods = entry.get("CurrentGpOds")
113+
custodian = entry.get("Custodian")
114+
115+
if current_gp_ods is None:
116+
self.logger.warning(f"[Custodian] CurrentGpOds is missing for item {entry.get('ID')}")
117+
return None
118+
if current_gp_ods is None or current_gp_ods != custodian:
119+
return {"Custodian": current_gp_ods}
120+
121+
return None
122+
123+
@staticmethod
124+
def get_update_status_items(entry: dict) -> dict | None:
125+
"""
126+
Ensures the 'Status' field is set to 'current'.
127+
Returns a dict with the update or None.
128+
"""
129+
if entry.get("Status") != "current":
130+
return {"Status": "current"}
131+
return None
132+
133+
@staticmethod
134+
def get_update_document_snomed_code_type_items(entry: dict) -> dict | None:
135+
"""
136+
Ensures the 'DocumentSnomedCodeType' field matches the expected SNOMED code.
137+
Returns a dict with the update or None.
138+
"""
139+
expected_code = SnomedCodes.LLOYD_GEORGE.value.code
140+
if entry.get("DocumentSnomedCodeType") != expected_code:
141+
return {"DocumentSnomedCodeType": expected_code}
142+
return None
143+
144+
@staticmethod
145+
def get_update_version_items(entry: dict) -> dict | None:
146+
"""
147+
Ensures the 'Version' field matches the expected Version code.
148+
Returns a dict with the update or None.
149+
"""
150+
expected_version = "1"
151+
version_field = "Version"
152+
if entry.get(version_field) != expected_version:
153+
return {version_field: expected_version}
154+
return None
155+
156+
def get_update_doc_status_items(self, entry: dict) -> dict | None:
157+
"""
158+
Infers and updates the 'DocStatus' field if missing.
159+
Returns a dict with the update or None.
160+
"""
161+
try:
162+
document = DocumentReference(**entry)
163+
except Exception as e:
164+
self.logger.warning(f"[DocStatus] Skipping invalid item {entry.get('ID')}: {e}")
165+
return None
166+
167+
inferred_status = document.infer_doc_status()
168+
169+
if entry.get("uploaded") and entry.get("uploading"):
170+
self.logger.warning(f"{entry.get('ID')}: Document has a status of uploading and uploaded.")
171+
172+
if entry.get("DocStatus", "") == inferred_status:
173+
return None
174+
175+
self.logger.warning(f"{entry.get('ID')}: {inferred_status}")
176+
177+
if inferred_status:
178+
return {"DocStatus": inferred_status}
179+
180+
self.logger.warning(f"[DocStatus] Cannot determine status for item {entry.get('ID')}")
181+
return None
182+
183+
if __name__ == "__main__":
184+
parser = argparse.ArgumentParser(
185+
prog="dynamodb_migration.py",
186+
description="Migrate DynamoDB table columns",
187+
)
188+
parser.add_argument("environment", help="Environment prefix for DynamoDB table")
189+
parser.add_argument("table_name", help="DynamoDB table name to migrate")
190+
parser.add_argument(
191+
"--run-migration",
192+
action="store_true",
193+
help="Running migration, fields will be updated.",
194+
)
195+
args = parser.parse_args()
196+
197+
migration = VersionMigration(
198+
environment=args.environment,
199+
table_name=args.table_name,
200+
run_migration=args.run_migration,
201+
)
202+
203+
entries_to_process = list(
204+
migration.dynamo_service.stream_whole_table(migration.target_table)
205+
)
206+
207+
update_functions = migration.main(entries=entries_to_process)
208+
209+
for label, fn in update_functions:
210+
migration.process_entries(label=label, entries=entries_to_process, update_fn=fn)

lambdas/services/base/dynamo_service.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
from typing import Iterator
23
from typing import Optional
34

45
import boto3
@@ -37,14 +38,14 @@ def get_table(self, table_name):
3738
raise e
3839

3940
def query_table_by_index(
40-
self,
41-
table_name,
42-
index_name,
43-
search_key,
44-
search_condition: str,
45-
requested_fields: list[str] = None,
46-
query_filter: Attr | ConditionBase = None,
47-
exclusive_start_key: dict = None,
41+
self,
42+
table_name,
43+
index_name,
44+
search_key,
45+
search_condition: str,
46+
requested_fields: list[str] = None,
47+
query_filter: Attr | ConditionBase = None,
48+
exclusive_start_key: dict = None,
4849
):
4950
try:
5051
table = self.get_table(table_name)
@@ -77,7 +78,7 @@ def query_table_by_index(
7778
raise e
7879

7980
def query_with_pagination(
80-
self, table_name: str, search_key: str, search_condition: str
81+
self, table_name: str, search_key: str, search_condition: str
8182
):
8283

8384
try:
@@ -133,12 +134,12 @@ def create_item(self, table_name, item):
133134
raise e
134135

135136
def update_item(
136-
self,
137-
table_name: str,
138-
key_pair: dict[str, str],
139-
updated_fields: dict,
140-
condition_expression: str = None,
141-
expression_attribute_values: dict = None,
137+
self,
138+
table_name: str,
139+
key_pair: dict[str, str],
140+
updated_fields: dict,
141+
condition_expression: str = None,
142+
expression_attribute_values: dict = None,
142143
):
143144
table = self.get_table(table_name)
144145
updated_field_names = list(updated_fields.keys())
@@ -177,10 +178,10 @@ def delete_item(self, table_name: str, key: dict):
177178
raise e
178179

179180
def scan_table(
180-
self,
181-
table_name: str,
182-
exclusive_start_key: dict = None,
183-
filter_expression: str = None,
181+
self,
182+
table_name: str,
183+
exclusive_start_key: dict = None,
184+
filter_expression: str = None,
184185
):
185186
try:
186187
table = self.get_table(table_name)
@@ -199,10 +200,10 @@ def scan_table(
199200
raise e
200201

201202
def scan_whole_table(
202-
self,
203-
table_name: str,
204-
project_expression: Optional[str] = None,
205-
filter_expression: Optional[str] = None,
203+
self,
204+
table_name: str,
205+
project_expression: Optional[str] = None,
206+
filter_expression: Optional[str] = None,
206207
) -> list[dict]:
207208
try:
208209
table = self.get_table(table_name)
@@ -265,7 +266,7 @@ def batch_get_items(self, table_name: str, key_list: list[str]):
265266
)
266267
request_items = unprocessed_keys
267268
retries += 1
268-
time.sleep((2**retries) * 0.1)
269+
time.sleep((2 ** retries) * 0.1)
269270
else:
270271
break
271272

@@ -284,3 +285,38 @@ def get_item(self, table_name: str, key: dict):
284285
str(e), {"Result": f"Unable to retrieve item from table: {table_name}"}
285286
)
286287
raise e
288+
289+
def stream_whole_table(
290+
self,
291+
table_name: str,
292+
filter_expression: Optional[str] = None,
293+
projection_expression: Optional[str] = None,
294+
) -> Iterator[dict]:
295+
"""
296+
Streams all items from a DynamoDB table using pagination.
297+
Yields one item at a time instead of loading everything into memory.
298+
"""
299+
try:
300+
table = self.get_table(table_name)
301+
scan_kwargs = {}
302+
303+
if filter_expression:
304+
scan_kwargs["FilterExpression"] = filter_expression
305+
if projection_expression:
306+
scan_kwargs["ProjectionExpression"] = projection_expression
307+
308+
response = table.scan(**scan_kwargs)
309+
310+
for item in response.get("Items", []):
311+
yield item
312+
313+
while "LastEvaluatedKey" in response:
314+
response = table.scan(
315+
ExclusiveStartKey=response["LastEvaluatedKey"], **scan_kwargs
316+
)
317+
for item in response.get("Items", []):
318+
yield item
319+
320+
except ClientError as e:
321+
logger.error(str(e), {"Result": f"Unable to stream table: {table_name}"})
322+
raise e

0 commit comments

Comments
 (0)