diff --git a/backend/src/controller/__init__.py b/backend/src/controller/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/src/controller/aws_apig_event_utils.py b/backend/src/controller/aws_apig_event_utils.py new file mode 100644 index 000000000..63146694c --- /dev/null +++ b/backend/src/controller/aws_apig_event_utils.py @@ -0,0 +1,27 @@ +"""Utility module for interacting with the AWS API Gateway event provided to controllers""" +from typing import Optional + +from aws_lambda_typing.events import APIGatewayProxyEventV1 + +from controller.constants import SUPPLIER_SYSTEM_HEADER_NAME +from models.errors import UnauthorizedError +from utils import dict_utils + + +def get_path_parameter(event: APIGatewayProxyEventV1, param_name: str) -> str: + return dict_utils.get_field( + event["pathParameters"], + param_name, + default="" + ) + + +def get_supplier_system_header(event: APIGatewayProxyEventV1) -> str: + """Retrieves the supplier system header from the API Gateway event""" + supplier_system: Optional[str] = dict_utils.get_field(dict(event), "headers", SUPPLIER_SYSTEM_HEADER_NAME) + + if supplier_system is None: + # SupplierSystem header must be provided for looking up permissions + raise UnauthorizedError() + + return supplier_system diff --git a/backend/src/controller/aws_apig_response_utils.py b/backend/src/controller/aws_apig_response_utils.py new file mode 100644 index 000000000..858a799eb --- /dev/null +++ b/backend/src/controller/aws_apig_response_utils.py @@ -0,0 +1,24 @@ +"""Utility module providing helper functions for dealing with response formats for AWS API Gateway""" +import json +from typing import Optional + + +def create_response( + status_code: int, + body: Optional[dict | str] = None, + headers: Optional[dict] = None +): + """Creates response body as per Lambda -> API Gateway proxy integration""" + if body is not None: + if isinstance(body, dict): + body = json.dumps(body) + if headers: + headers["Content-Type"] = "application/fhir+json" + else: + headers = {"Content-Type": "application/fhir+json"} + + return { + "statusCode": status_code, + "headers": headers if headers else {}, + **({"body": body} if body else {}), + } diff --git a/backend/src/controller/constants.py b/backend/src/controller/constants.py new file mode 100644 index 000000000..7e6b2357e --- /dev/null +++ b/backend/src/controller/constants.py @@ -0,0 +1,5 @@ +"""FHIR Controller constants""" + + +SUPPLIER_SYSTEM_HEADER_NAME = "SupplierSystem" +E_TAG_HEADER_NAME = "E-Tag" diff --git a/backend/src/controller/fhir_api_exception_handler.py b/backend/src/controller/fhir_api_exception_handler.py new file mode 100644 index 000000000..fc793f683 --- /dev/null +++ b/backend/src/controller/fhir_api_exception_handler.py @@ -0,0 +1,42 @@ +"""Module for the global FHIR API exception handler""" +import functools +import uuid +from typing import Callable, Type + +from clients import logger +from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE +from controller.aws_apig_response_utils import create_response +from models.errors import UnauthorizedVaxError, UnauthorizedError, ResourceNotFoundError, create_operation_outcome, \ + Severity, Code + + +_CUSTOM_EXCEPTION_TO_STATUS_MAP: dict[Type[Exception], int] = { + UnauthorizedError: 403, + UnauthorizedVaxError: 403, + ResourceNotFoundError: 404 +} + + +def fhir_api_exception_handler(function: Callable) -> Callable: + """Decorator to handle any expected FHIR API exceptions or unexpected exception and provide a valid response to + the client""" + + @functools.wraps(function) + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except tuple(_CUSTOM_EXCEPTION_TO_STATUS_MAP) as exc: + status_code = _CUSTOM_EXCEPTION_TO_STATUS_MAP[type(exc)] + return create_response(status_code=status_code, body=exc.to_operation_outcome()) + except Exception: # pylint: disable = broad-exception-caught + logger.exception("Unhandled exception") + server_error = create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.server_error, + diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, + ) + return create_response(500, server_error) + + return wrapper + diff --git a/backend/src/fhir_batch_controller.py b/backend/src/controller/fhir_batch_controller.py similarity index 91% rename from backend/src/fhir_batch_controller.py rename to backend/src/controller/fhir_batch_controller.py index 90f68a40c..fbfb99a3f 100644 --- a/backend/src/fhir_batch_controller.py +++ b/backend/src/controller/fhir_batch_controller.py @@ -1,7 +1,7 @@ """Function to send the request directly to lambda (or return appropriate diagnostics if this is not possible)""" -from fhir_batch_service import ImmunizationBatchService -from fhir_batch_repository import ImmunizationBatchRepository +from service.fhir_batch_service import ImmunizationBatchService +from repository.fhir_batch_repository import ImmunizationBatchRepository def make_batch_controller(): diff --git a/backend/src/fhir_controller.py b/backend/src/controller/fhir_controller.py similarity index 78% rename from backend/src/fhir_controller.py rename to backend/src/controller/fhir_controller.py index 577384346..2804a34ce 100644 --- a/backend/src/fhir_controller.py +++ b/backend/src/controller/fhir_controller.py @@ -6,11 +6,13 @@ from decimal import Decimal from typing import Optional from aws_lambda_typing.events import APIGatewayProxyEventV1 -from fhir.resources.R4B.immunization import Immunization -from boto3 import client as boto3_client -from fhir_repository import ImmunizationRepository, create_table -from fhir_service import FhirService, UpdateOutcome, get_service_url +from controller.aws_apig_event_utils import get_supplier_system_header, get_path_parameter +from controller.aws_apig_response_utils import create_response +from controller.constants import E_TAG_HEADER_NAME +from controller.fhir_api_exception_handler import fhir_api_exception_handler +from repository.fhir_repository import ImmunizationRepository, create_table +from service.fhir_service import FhirService, UpdateOutcome, get_service_url from models.errors import ( Severity, Code, @@ -27,9 +29,6 @@ from parameter_parser import process_params, process_search_params, create_query_string import urllib.parse -sqs_client = boto3_client("sqs", region_name="eu-west-2") -queue_url = os.getenv("SQS_QUEUE_URL", "Queue_url") - def make_controller( immunization_env: str = os.getenv("IMMUNIZATION_ENV"), @@ -58,7 +57,7 @@ def get_immunization_by_identifier(self, aws_event) -> dict: else: raise UnauthorizedError() except UnauthorizedError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) body = aws_event["body"] if query_params and body: error = create_operation_outcome( @@ -67,7 +66,7 @@ def get_immunization_by_identifier(self, aws_event) -> dict: code=Code.invalid, diagnostics=('Parameters may not be duplicated. Use commas for "or".'), ) - return self.create_response(400, error) + return create_response(400, error) identifier, element, not_required, has_imms_identifier, has_element = self.fetch_identifier_system_and_element( aws_event ) @@ -75,55 +74,33 @@ def get_immunization_by_identifier(self, aws_event) -> dict: return self.create_response_for_identifier(not_required, has_imms_identifier, has_element) # If not found, retrieve from multiValueQueryStringParameters if id_error := self._validate_identifier_system(identifier, element): - return self.create_response(400, id_error) + return create_response(400, id_error) identifiers = identifier.replace("|", "#") supplier_system = self._identify_supplier_system(aws_event) try: if resource := self.fhir_service.get_immunization_by_identifier( identifiers, supplier_system, identifier, element): - return FhirController.create_response(200, resource) + return create_response(200, resource) except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) + + @fhir_api_exception_handler + def get_immunization_by_id(self, aws_event: APIGatewayProxyEventV1) -> dict: + imms_id = get_path_parameter(aws_event, "id") - def get_immunization_by_id(self, aws_event) -> dict: - imms_id = aws_event["pathParameters"]["id"] if id_error := self._validate_id(imms_id): - return self.create_response(400, id_error) + return create_response(400, id_error) - try: - if aws_event.get("headers"): - supplier_system = self._identify_supplier_system(aws_event) - else: - raise UnauthorizedError() - except UnauthorizedError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + supplier_system = get_supplier_system_header(aws_event) - try: - if resource := self.fhir_service.get_immunization_by_id(imms_id, supplier_system): - version = str() - if isinstance(resource, Immunization): - resp = resource - else: - resp = resource["Resource"] - if resource.get("Version"): - version = resource["Version"] - return FhirController.create_response(200, resp.json(), {"E-Tag": version}) - else: - msg = "The requested resource was not found." - id_error = create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.not_found, - diagnostics=msg, - ) - return FhirController.create_response(404, id_error) - except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + resource, version = self.fhir_service.get_immunization_and_version_by_id(imms_id, supplier_system) + + return create_response(200, resource.json(), {E_TAG_HEADER_NAME: version}) def create_immunization(self, aws_event): if not aws_event.get("headers"): - return self.create_response( + return create_response( 403, create_operation_outcome( resource_id=str(uuid.uuid4()), @@ -148,19 +125,19 @@ def create_immunization(self, aws_event): code=Code.invariant, diagnostics=resource["diagnostics"], ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) else: location = f"{get_service_url()}/Immunization/{resource.id}" version = "1" - return self.create_response(201, None, {"Location": location, "E-Tag": version}) + return create_response(201, None, {"Location": location, "E-Tag": version}) except ValidationError as error: - return self.create_response(400, error.to_operation_outcome()) + return create_response(400, error.to_operation_outcome()) except IdentifierDuplicationError as duplicate: - return self.create_response(422, duplicate.to_operation_outcome()) + return create_response(422, duplicate.to_operation_outcome()) except UnhandledResponseError as unhandled_error: - return self.create_response(500, unhandled_error.to_operation_outcome()) + return create_response(500, unhandled_error.to_operation_outcome()) except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) def update_immunization(self, aws_event): try: @@ -169,13 +146,13 @@ def update_immunization(self, aws_event): else: raise UnauthorizedError() except UnauthorizedError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) supplier_system = self._identify_supplier_system(aws_event) # Validate the imms id - start if id_error := self._validate_id(imms_id): - return FhirController.create_response(400, json.dumps(id_error)) + return create_response(400, json.dumps(id_error)) # Validate the imms id - end # Validate the body of the request - start @@ -189,7 +166,7 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=f"Validation errors: The provided immunization id:{imms_id} doesn't match with the content of the request body", ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) # Validate the imms id in the path params and body of request - end except json.decoder.JSONDecodeError as e: return self._create_bad_request(f"Request's body contains malformed JSON: {e}") @@ -207,7 +184,7 @@ def update_immunization(self, aws_event): code=Code.not_found, diagnostics=f"Validation errors: The requested immunization resource with id:{imms_id} was not found.", ) - return self.create_response(404, json.dumps(exp_error)) + return create_response(404, json.dumps(exp_error)) if "diagnostics" in existing_record: exp_error = create_operation_outcome( @@ -216,9 +193,9 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=existing_record["diagnostics"], ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) except ValidationError as error: - return self.create_response(400, error.to_operation_outcome()) + return create_response(400, error.to_operation_outcome()) # Validate if the imms resource does not exist - end existing_resource_version = int(existing_record["Version"]) @@ -244,7 +221,7 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics="Validation errors: Immunization resource version not specified in the request headers", ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) # Validate if imms resource version is part of the request - end # Validate the imms resource version provided in the request headers - start @@ -258,7 +235,7 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=f"Validation errors: Immunization resource version:{resource_version} in the request headers is invalid.", ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) # Validate the imms resource version provided in the request headers - end # Validate if resource version has changed since the last retrieve - start @@ -269,7 +246,7 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=f"Validation errors: The requested immunization resource {imms_id} has changed since the last retrieve.", ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) if existing_resource_version < resource_version_header: exp_error = create_operation_outcome( resource_id=str(uuid.uuid4()), @@ -277,7 +254,7 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=f"Validation errors: The requested immunization resource {imms_id} version is inconsistent with the existing version.", ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) # Validate if resource version has changed since the last retrieve - end # Check if the record is reinstated record - start @@ -308,15 +285,15 @@ def update_immunization(self, aws_event): code=Code.invariant, diagnostics=resource["diagnostics"], ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) if outcome == UpdateOutcome.UPDATE: - return self.create_response(200, None, {"E-Tag": updated_version}) #include e-tag here, is it not included in the response resource + return create_response(200, None, {"E-Tag": updated_version}) #include e-tag here, is it not included in the response resource except ValidationError as error: - return self.create_response(400, error.to_operation_outcome()) + return create_response(400, error.to_operation_outcome()) except IdentifierDuplicationError as duplicate: - return self.create_response(422, duplicate.to_operation_outcome()) + return create_response(422, duplicate.to_operation_outcome()) except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) def delete_immunization(self, aws_event): try: @@ -325,24 +302,24 @@ def delete_immunization(self, aws_event): else: raise UnauthorizedError() except UnauthorizedError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) # Validate the imms id if id_error := self._validate_id(imms_id): - return FhirController.create_response(400, json.dumps(id_error)) + return create_response(400, json.dumps(id_error)) supplier_system = self._identify_supplier_system(aws_event) try: self.fhir_service.delete_immunization(imms_id, supplier_system) - return self.create_response(204) + return create_response(204) except ResourceNotFoundError as not_found: - return self.create_response(404, not_found.to_operation_outcome()) + return create_response(404, not_found.to_operation_outcome()) except UnhandledResponseError as unhandled_error: - return self.create_response(500, unhandled_error.to_operation_outcome()) + return create_response(500, unhandled_error.to_operation_outcome()) except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) def search_immunizations(self, aws_event: APIGatewayProxyEventV1) -> dict: try: @@ -359,7 +336,7 @@ def search_immunizations(self, aws_event: APIGatewayProxyEventV1) -> dict: else: raise UnauthorizedError() except UnauthorizedError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) try: result, request_contained_unauthorised_vaccs = self.fhir_service.search_immunizations( @@ -371,7 +348,7 @@ def search_immunizations(self, aws_event: APIGatewayProxyEventV1) -> dict: search_params.date_to, ) except UnauthorizedVaxError as unauthorized: - return self.create_response(403, unauthorized.to_operation_outcome()) + return create_response(403, unauthorized.to_operation_outcome()) if "diagnostics" in result: exp_error = create_operation_outcome( @@ -380,7 +357,7 @@ def search_immunizations(self, aws_event: APIGatewayProxyEventV1) -> dict: code=Code.invariant, diagnostics=result["diagnostics"], ) - return self.create_response(400, json.dumps(exp_error)) + return create_response(400, json.dumps(exp_error)) # Workaround for fhir.resources JSON removing the empty "entry" list. result_json_dict: dict = json.loads(result.json()) if "entry" in result_json_dict: @@ -404,7 +381,7 @@ def search_immunizations(self, aws_event: APIGatewayProxyEventV1) -> dict: if "entry" not in result_json_dict: result_json_dict["entry"] = [] result_json_dict["total"] = 0 - return self.create_response(200, json.dumps(result_json_dict)) + return create_response(200, json.dumps(result_json_dict)) def _validate_id(self, _id: str) -> Optional[dict]: if not re.match(self.immunization_id_pattern, _id): @@ -415,8 +392,8 @@ def _validate_id(self, _id: str) -> Optional[dict]: code=Code.invalid, diagnostics=msg, ) - else: - return None + + return None def _validate_identifier_system(self, _id: str, _elements: str) -> Optional[dict]: if not _id: @@ -462,7 +439,7 @@ def _create_bad_request(self, message): code=Code.invalid, diagnostics=message, ) - return self.create_response(400, error) + return create_response(400, error) def fetch_identifier_system_and_element(self, event: dict): """ @@ -531,7 +508,7 @@ def create_response_for_identifier(self, not_required, has_identifier, has_eleme code=Code.server_error, diagnostics="Search parameter should have either identifier or patient.identifier", ) - return self.create_response(400, error) + return create_response(400, error) if not_required and has_element: error = create_operation_outcome( @@ -540,27 +517,11 @@ def create_response_for_identifier(self, not_required, has_identifier, has_eleme code=Code.server_error, diagnostics="Search parameter _elements must have the following parameter: identifier", ) - return self.create_response(400, error) - - @staticmethod - def create_response(status_code, body=None, headers=None): - if body: - if isinstance(body, dict): - body = json.dumps(body) - if headers: - headers["Content-Type"] = "application/fhir+json" - else: - headers = {"Content-Type": "application/fhir+json"} - - return { - "statusCode": status_code, - "headers": headers if headers else {}, - **({"body": body} if body else {}), - } + return create_response(400, error) @staticmethod def _identify_supplier_system(aws_event): supplier_system = aws_event["headers"]["SupplierSystem"] if not supplier_system: - raise UnauthorizedError("SupplierSystem header is missing") + raise UnauthorizedError() return supplier_system diff --git a/backend/src/create_imms_handler.py b/backend/src/create_imms_handler.py index e3bd192f8..630d6c32d 100644 --- a/backend/src/create_imms_handler.py +++ b/backend/src/create_imms_handler.py @@ -3,8 +3,8 @@ import pprint import uuid - -from fhir_controller import FhirController, make_controller +from controller.aws_apig_response_utils import create_response +from controller.fhir_controller import FhirController, make_controller from local_lambda import load_string from models.errors import Severity, Code, create_operation_outcome from log_structure import function_info @@ -29,7 +29,7 @@ def create_immunization(event, controller: FhirController): code=Code.server_error, diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, ) - return FhirController.create_response(500, exp_error) + return create_response(500, exp_error) if __name__ == "__main__": diff --git a/backend/src/delete_imms_handler.py b/backend/src/delete_imms_handler.py index 14458de2f..4838c7790 100644 --- a/backend/src/delete_imms_handler.py +++ b/backend/src/delete_imms_handler.py @@ -3,8 +3,8 @@ import pprint import uuid - -from fhir_controller import FhirController, make_controller +from controller.aws_apig_response_utils import create_response +from controller.fhir_controller import FhirController, make_controller from models.errors import Severity, Code, create_operation_outcome from log_structure import function_info from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE @@ -28,7 +28,7 @@ def delete_immunization(event, controller: FhirController): code=Code.server_error, diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, ) - return FhirController.create_response(500, exp_error) + return create_response(500, exp_error) if __name__ == "__main__": diff --git a/backend/src/forwarding_batch_lambda.py b/backend/src/forwarding_batch_lambda.py index 70e832c0b..14178b080 100644 --- a/backend/src/forwarding_batch_lambda.py +++ b/backend/src/forwarding_batch_lambda.py @@ -8,8 +8,8 @@ from datetime import datetime from batch.batch_filename_to_events_mapper import BatchFilenameToEventsMapper -from fhir_batch_repository import create_table -from fhir_batch_controller import ImmunizationBatchController, make_batch_controller +from repository.fhir_batch_repository import create_table +from controller.fhir_batch_controller import ImmunizationBatchController, make_batch_controller from clients import sqs_client from models.errors import ( MessageNotSuccessfulError, diff --git a/backend/src/get_imms_handler.py b/backend/src/get_imms_handler.py index 3ae4792ee..5c31f11f0 100644 --- a/backend/src/get_imms_handler.py +++ b/backend/src/get_imms_handler.py @@ -4,7 +4,7 @@ import uuid -from fhir_controller import FhirController, make_controller +from controller.fhir_controller import FhirController, make_controller from models.errors import Severity, Code, create_operation_outcome from log_structure import function_info from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE @@ -18,17 +18,7 @@ def get_imms_handler(event, _context): def get_immunization_by_id(event, controller: FhirController): - try: - return controller.get_immunization_by_id(event) - except Exception: # pylint: disable = broad-exception-caught - logger.exception("Unhandled exception") - exp_error = create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.server_error, - diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, - ) - return FhirController.create_response(500, exp_error) + return controller.get_immunization_by_id(event) if __name__ == "__main__": diff --git a/backend/src/models/fhir_immunization_post_validators.py b/backend/src/models/fhir_immunization_post_validators.py index c874b2221..cea4e7329 100644 --- a/backend/src/models/fhir_immunization_post_validators.py +++ b/backend/src/models/fhir_immunization_post_validators.py @@ -1,12 +1,11 @@ "FHIR Immunization Post Validators" from models.errors import MandatoryError -from models.obtain_field_value import ObtainFieldValue from models.validation_sets import ValidationSets from models.mandation_functions import MandationFunctions from models.field_names import FieldNames from models.field_locations import FieldLocations -from base_utils.base_utils import obtain_field_value, obtain_field_location +from models.utils.base_utils import obtain_field_value, obtain_field_location class PostValidators: diff --git a/backend/src/models/mandation_functions.py b/backend/src/models/mandation_functions.py index 9b1cdc7dc..99d2ef526 100644 --- a/backend/src/models/mandation_functions.py +++ b/backend/src/models/mandation_functions.py @@ -3,9 +3,6 @@ from dataclasses import dataclass from models.errors import MandatoryError -from models.field_locations import FieldLocations -from models.field_names import FieldNames -from base_utils.base_utils import obtain_field_value @dataclass diff --git a/backend/src/base_utils/base_utils.py b/backend/src/models/utils/base_utils.py similarity index 100% rename from backend/src/base_utils/base_utils.py rename to backend/src/models/utils/base_utils.py diff --git a/backend/src/models/utils/validation_utils.py b/backend/src/models/utils/validation_utils.py index 742bc0d3a..60eec9b6f 100644 --- a/backend/src/models/utils/validation_utils.py +++ b/backend/src/models/utils/validation_utils.py @@ -4,7 +4,7 @@ from typing import Union from .generic_utils import create_diagnostics_error -from base_utils.base_utils import obtain_field_location +from models.utils.base_utils import obtain_field_location from models.obtain_field_value import ObtainFieldValue from models.field_names import FieldNames from models.errors import MandatoryError @@ -54,7 +54,7 @@ def convert_disease_codes_to_vaccine_type(disease_codes_input: list) -> Union[st """ key = ":".join(sorted(disease_codes_input)) vaccine_type = redis_client.hget(Constants.DISEASES_TO_VACCINE_TYPE_HASH_KEY, key) - + if not vaccine_type: raise ValueError( f"Validation errors: protocolApplied[0].targetDisease[*].coding[?(@.system=='http://snomed.info/sct')].code - " diff --git a/backend/src/repository/__init__.py b/backend/src/repository/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/src/fhir_batch_repository.py b/backend/src/repository/fhir_batch_repository.py similarity index 98% rename from backend/src/fhir_batch_repository.py rename to backend/src/repository/fhir_batch_repository.py index f9c259dc3..3730adfe0 100644 --- a/backend/src/fhir_batch_repository.py +++ b/backend/src/repository/fhir_batch_repository.py @@ -87,8 +87,6 @@ def __init__(self, imms: dict, vax_type: str, supplier: str, version: int): class ImmunizationBatchRepository: - def __init__(self): - pass def create_immunization( self, immunization: any, supplier_system: str, vax_type: str, table: any, is_present: bool @@ -252,7 +250,7 @@ def _perform_dynamo_update( else Attr("PK").eq(attr.pk) & Attr("DeletedAt").not_exists() ) if deleted_at_required and update_reinstated == False: - ExpressionAttributeValues = { + expression_attribute_values = { ":timestamp": attr.timestamp, ":patient_pk": attr.patient_pk, ":patient_sk": attr.patient_sk, @@ -263,7 +261,7 @@ def _perform_dynamo_update( ":respawn": "reinstated", } else: - ExpressionAttributeValues = { + expression_attribute_values = { ":timestamp": attr.timestamp, ":patient_pk": attr.patient_pk, ":patient_sk": attr.patient_sk, @@ -279,7 +277,7 @@ def _perform_dynamo_update( ExpressionAttributeNames={ "#imms_resource": "Resource", }, - ExpressionAttributeValues=ExpressionAttributeValues, + ExpressionAttributeValues=expression_attribute_values, ReturnValues="ALL_NEW", ConditionExpression=condition_expression, ) diff --git a/backend/src/fhir_repository.py b/backend/src/repository/fhir_repository.py similarity index 98% rename from backend/src/fhir_repository.py rename to backend/src/repository/fhir_repository.py index fed207595..80ebe765e 100644 --- a/backend/src/fhir_repository.py +++ b/backend/src/repository/fhir_repository.py @@ -102,19 +102,17 @@ def get_immunization_by_identifier(self, identifier_pk: str) -> tuple[Optional[d else: return None, None - def get_immunization_by_id(self, imms_id: str) -> Optional[dict]: + def get_immunization_and_version_by_id(self, imms_id: str) -> tuple[Optional[dict], Optional[str]]: response = self.table.get_item(Key={"PK": _make_immunization_pk(imms_id)}) item = response.get("Item") if not item: - return None + return None, None + if item.get("DeletedAt") and item["DeletedAt"] != "reinstated": - return None + return None, None - return { - "Resource": json.loads(item["Resource"]), - "Version": item["Version"] - } + return json.loads(item.get("Resource", {})), str(item.get("Version", "")) def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]: response = self.table.get_item(Key={"PK": _make_immunization_pk(imms_id)}) diff --git a/backend/src/search_imms_handler.py b/backend/src/search_imms_handler.py index acf9dd5d2..513727ec7 100644 --- a/backend/src/search_imms_handler.py +++ b/backend/src/search_imms_handler.py @@ -6,8 +6,8 @@ from aws_lambda_typing import context as context_, events - -from fhir_controller import FhirController, make_controller +from controller.aws_apig_response_utils import create_response +from controller.fhir_controller import FhirController, make_controller from models.errors import Severity, Code, create_operation_outcome from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, MAX_RESPONSE_SIZE_BYTES from log_structure import function_info @@ -64,7 +64,7 @@ def search_imms(event: events.APIGatewayProxyEventV1, controller: FhirController code=Code.invalid, diagnostics="Search returned too many results. Please narrow down the search", ) - return FhirController.create_response(400, exp_error) + return create_response(400, exp_error) return response except Exception: # pylint: disable = broad-exception-caught logger.exception("Unhandled exception") @@ -74,7 +74,7 @@ def search_imms(event: events.APIGatewayProxyEventV1, controller: FhirController code=Code.server_error, diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, ) - return FhirController.create_response(500, exp_error) + return create_response(500, exp_error) if __name__ == "__main__": diff --git a/backend/src/service/__init__.py b/backend/src/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/src/fhir_batch_service.py b/backend/src/service/fhir_batch_service.py similarity index 97% rename from backend/src/fhir_batch_service.py rename to backend/src/service/fhir_batch_service.py index 6632bfc3a..c618c4c95 100644 --- a/backend/src/fhir_batch_service.py +++ b/backend/src/service/fhir_batch_service.py @@ -1,5 +1,5 @@ from pydantic import ValidationError -from fhir_batch_repository import ImmunizationBatchRepository +from repository.fhir_batch_repository import ImmunizationBatchRepository from models.errors import CustomValidationError from models.fhir_immunization import ImmunizationValidator from models.errors import MandatoryError diff --git a/backend/src/fhir_service.py b/backend/src/service/fhir_service.py similarity index 95% rename from backend/src/fhir_service.py rename to backend/src/service/fhir_service.py index 996725993..0e4c50921 100644 --- a/backend/src/fhir_service.py +++ b/backend/src/service/fhir_service.py @@ -17,7 +17,7 @@ import parameter_parser from authorisation.api_operation_code import ApiOperationCode from authorisation.authoriser import Authoriser -from fhir_repository import ImmunizationRepository +from repository.fhir_repository import ImmunizationRepository from models.errors import InvalidPatientId, CustomValidationError, UnauthorizedVaxError, ResourceNotFoundError from models.fhir_immunization import ImmunizationValidator from models.utils.generic_utils import nhs_number_mod11_check, get_occurrence_datetime, create_diagnostics, form_json, get_contained_patient @@ -31,7 +31,7 @@ def get_service_url(service_env: str = os.getenv("IMMUNIZATION_ENV"), service_base_path: str = os.getenv("IMMUNIZATION_BASE_PATH") ) -> str: - + if not service_base_path: service_base_path = "immunisation-fhir-api/FHIR/R4" @@ -83,25 +83,21 @@ def get_immunization_by_identifier( imms_resp['resource'] = filtered_resource return form_json(imms_resp, element, identifier, base_url) - def get_immunization_by_id(self, imms_id: str, supplier_system: str) -> Optional[dict]: + def get_immunization_and_version_by_id(self, imms_id: str, supplier_system: str) -> tuple[Immunization, str]: """ - Get an Immunization by its ID. Return None if it is not found. If the patient doesn't have an NHS number, - return the Immunization. + Get an Immunization by its ID. Returns the immunization entity and version number. """ - if not (imms_resp := self.immunization_repo.get_immunization_by_id(imms_id)): - return None + resource, version = self.immunization_repo.get_immunization_and_version_by_id(imms_id) + + if resource is None: + raise ResourceNotFoundError(resource_type="Immunization", resource_id=imms_id) - # Returns the Immunisation full resource with no obfuscation - resource = imms_resp.get("Resource", {}) vaccination_type = get_vaccine_type(resource) if not self.authoriser.authorise(supplier_system, ApiOperationCode.READ, {vaccination_type}): raise UnauthorizedVaxError() - return { - "Version": imms_resp.get("Version", ""), - "Resource": Immunization.parse_obj(resource), - } + return Immunization.parse_obj(resource), version def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]: """ @@ -234,12 +230,12 @@ def delete_immunization(self, imms_id: str, supplier_system: str) -> Immunizatio Exception will be raised if resource does not exist. Multiple calls to this method won't change the record in the database. """ - existing_immunisation = self.immunization_repo.get_immunization_by_id(imms_id) + existing_immunisation, _ = self.immunization_repo.get_immunization_and_version_by_id(imms_id) if not existing_immunisation: raise ResourceNotFoundError(resource_type="Immunization", resource_id=imms_id) - vaccination_type = get_vaccine_type(existing_immunisation.get("Resource", {})) + vaccination_type = get_vaccine_type(existing_immunisation) if not self.authoriser.authorise(supplier_system, ApiOperationCode.DELETE, {vaccination_type}): raise UnauthorizedVaxError() diff --git a/backend/src/update_imms_handler.py b/backend/src/update_imms_handler.py index 7d6c89794..e8af7f06b 100644 --- a/backend/src/update_imms_handler.py +++ b/backend/src/update_imms_handler.py @@ -3,8 +3,8 @@ import pprint import uuid - -from fhir_controller import FhirController, make_controller +from controller.aws_apig_response_utils import create_response +from controller.fhir_controller import FhirController, make_controller from local_lambda import load_string from models.errors import Severity, Code, create_operation_outcome from log_structure import function_info @@ -29,7 +29,7 @@ def update_imms(event, controller: FhirController): code=Code.server_error, diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, ) - return FhirController.create_response(500, exp_error) + return create_response(500, exp_error) if __name__ == "__main__": @@ -49,4 +49,3 @@ def update_imms(event, controller: FhirController): pprint.pprint(event) pprint.pprint(update_imms_handler(event, {})) - \ No newline at end of file diff --git a/backend/src/utils/__init__.py b/backend/src/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/src/utils/dict_utils.py b/backend/src/utils/dict_utils.py new file mode 100644 index 000000000..2874dc07f --- /dev/null +++ b/backend/src/utils/dict_utils.py @@ -0,0 +1,21 @@ +"""Generic helper module for Python dictionary utility functions""" +from typing import Optional, Any + + +def get_field(target_dict: dict, *args: str, default: Optional[Any] = None) -> Any: + """Safely retrieves a value from a dictionary. Supports nested dictionaries.""" + if not target_dict or not isinstance(target_dict, dict): + return default + + latest_nested_dict = dict(target_dict) + + for key in args: + if key not in latest_nested_dict: + return default + + if not isinstance(latest_nested_dict[key], dict): + return latest_nested_dict[key] + + latest_nested_dict = latest_nested_dict[key] + + return latest_nested_dict diff --git a/backend/tests/controller/__init__.py b/backend/tests/controller/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/controller/test_fhir_api_exception_handler.py b/backend/tests/controller/test_fhir_api_exception_handler.py new file mode 100644 index 000000000..eb4438437 --- /dev/null +++ b/backend/tests/controller/test_fhir_api_exception_handler.py @@ -0,0 +1,71 @@ +import json +import unittest +from unittest.mock import patch + +from controller.fhir_api_exception_handler import fhir_api_exception_handler +from models.errors import UnauthorizedError, UnauthorizedVaxError, ResourceNotFoundError + + +class TestFhirApiExceptionHandler(unittest.TestCase): + def setUp(self): + self.logger_patcher = patch("controller.fhir_api_exception_handler.logger") + self.mock_logger = self.logger_patcher.start() + + def tearDown(self): + patch.stopall() + + def test_exception_handler_does_nothing_when_no_exception_occurs(self): + """Test that when the wrapped function returns successfully then the wrapper does nothing""" + @fhir_api_exception_handler + def dummy_func(): + return "Hello World" + + self.mock_logger.exception.assert_not_called() + self.assertEqual(dummy_func(), "Hello World") + + def test_exception_handler_handles_custom_exception_and_returns_fhir_response(self): + """Test that custom exceptions are handled by the wrapper and a valid response is returned to the client""" + test_cases = [ + (UnauthorizedError(), 403, "forbidden", "Unauthorized request"), + (UnauthorizedVaxError(), 403, "forbidden", "Unauthorized request for vaccine type"), + (ResourceNotFoundError(resource_type="Immunization", resource_id="123"), 404, "not-found", + "Immunization resource does not exist. ID: 123") + ] + + for error, expected_status, expected_code, expected_message in test_cases: + with self.subTest(msg=f"Test {error.__class__.__name__}"): + + @fhir_api_exception_handler + def dummy_func(): + raise error + + response = dummy_func() + + self.mock_logger.exception.assert_not_called() + + operation_outcome = json.loads(response["body"]) + self.assertEqual(response["statusCode"], expected_status) + self.assertEqual(operation_outcome["resourceType"], "OperationOutcome") + self.assertEqual(operation_outcome["issue"][0]["code"], expected_code) + self.assertEqual(operation_outcome["issue"][0]["diagnostics"], expected_message) + + + def test_exception_handler_logs_exception_when_unexpected_error_occurs(self): + """Test that when an unexpected exception occurs the exception is logged and an appropriate response is + returned""" + @fhir_api_exception_handler + def dummy_func(): + raise Exception("Something went very wrong") + + response = dummy_func() + + self.mock_logger.exception.assert_called_once_with("Unhandled exception") + + operation_outcome = json.loads(response["body"]) + self.assertEqual(response["statusCode"], 500) + self.assertEqual(operation_outcome["resourceType"], "OperationOutcome") + self.assertEqual(operation_outcome["issue"][0]["code"], "exception") + self.assertEqual( + operation_outcome["issue"][0]["diagnostics"], + "Unable to process request. Issue may be transient." + ) diff --git a/backend/tests/test_fhir_batch_controller.py b/backend/tests/controller/test_fhir_batch_controller.py similarity index 95% rename from backend/tests/test_fhir_batch_controller.py rename to backend/tests/controller/test_fhir_batch_controller.py index c8abcd35c..084291f8e 100644 --- a/backend/tests/test_fhir_batch_controller.py +++ b/backend/tests/controller/test_fhir_batch_controller.py @@ -1,16 +1,16 @@ import unittest import uuid from unittest.mock import Mock, create_autospec -from tests.utils.immunization_utils import create_covid_19_immunization -from fhir_batch_service import ImmunizationBatchService -from fhir_batch_repository import ImmunizationBatchRepository +from testing_utils.immunization_utils import create_covid_19_immunization +from service.fhir_batch_service import ImmunizationBatchService +from repository.fhir_batch_repository import ImmunizationBatchRepository from models.errors import ( ResourceNotFoundError, UnhandledResponseError, CustomValidationError, IdentifierDuplicationError ) -from fhir_batch_controller import ImmunizationBatchController +from controller.fhir_batch_controller import ImmunizationBatchController class TestCreateImmunizationBatchController(unittest.TestCase): @@ -36,9 +36,9 @@ def test_send_request_to_dynamo_create_success(self): } self.mock_service.create_immunization.return_value = imms_id - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, imms_id) self.mock_service.create_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -55,7 +55,7 @@ def test_send_request_to_dynamo_create_badrequest(self): imms_id = str(uuid.uuid4()) imms = create_covid_19_immunization(imms_id) create_result = CustomValidationError(message = "Validation errors: contained[?(@.resourceType=='Patient')].identifier[0].value does not exists") - + message_body = { "supplier": "test_supplier", "fhir_json": imms.json(), @@ -64,9 +64,9 @@ def test_send_request_to_dynamo_create_badrequest(self): } self.mock_service.create_immunization.return_value = create_result - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, create_result) self.mock_service.create_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -74,8 +74,8 @@ def test_send_request_to_dynamo_create_badrequest(self): vax_type=message_body['vax_type'], table=self.mock_table, is_present=True - ) - + ) + def test_send_request_to_dynamo_create_duplicate(self): """it should not create the Immunization since its a duplicate record""" @@ -89,10 +89,10 @@ def test_send_request_to_dynamo_create_duplicate(self): "operation_requested": "CREATE" } - self.mock_service.create_immunization.return_value = create_result - + self.mock_service.create_immunization.return_value = create_result + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, create_result) self.mock_service.create_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -115,10 +115,10 @@ def test_send_request_to_dynamo_create_unhandled_error(self): "operation_requested": "CREATE" } - self.mock_service.create_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") - + self.mock_service.create_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.create_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -126,7 +126,7 @@ def test_send_request_to_dynamo_create_unhandled_error(self): vax_type=message_body['vax_type'], table=self.mock_table, is_present=True - ) + ) class TestUpdateImmunizationBatchController(unittest.TestCase): @@ -152,9 +152,9 @@ def test_send_request_to_dynamo_update_success(self): } self.mock_service.update_immunization.return_value = imms_id - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, imms_id) self.mock_service.update_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -178,9 +178,9 @@ def test_send_request_to_dynamo_update_badrequest(self): } self.mock_service.update_immunization.return_value = update_result - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.update_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -188,8 +188,8 @@ def test_send_request_to_dynamo_update_badrequest(self): vax_type=message_body['vax_type'], table=self.mock_table, is_present=True - ) - + ) + def test_send_request_to_dynamo_update_resource_not_found(self): """it should not update the Immunization since no resource found for the record""" @@ -203,10 +203,10 @@ def test_send_request_to_dynamo_update_resource_not_found(self): "operation_requested": "UPDATE" } - self.mock_service.update_immunization.return_value = update_result - + self.mock_service.update_immunization.return_value = update_result + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.update_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -229,10 +229,10 @@ def test_send_request_to_dynamo_update_unhandled_error(self): "operation_requested": "UPDATE" } - self.mock_service.update_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") - + self.mock_service.update_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.update_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -240,8 +240,8 @@ def test_send_request_to_dynamo_update_unhandled_error(self): vax_type=message_body['vax_type'], table=self.mock_table, is_present=True - ) - + ) + class TestDeleteImmunizationBatchController(unittest.TestCase): @@ -267,9 +267,9 @@ def test_send_request_to_dynamo_delete_success(self): } self.mock_service.delete_immunization.return_value = imms_id - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, imms_id) self.mock_service.delete_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -293,9 +293,9 @@ def test_send_request_to_dynamo_delete_badrequest(self): } self.mock_service.delete_immunization.return_value = update_result - + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.delete_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -303,8 +303,8 @@ def test_send_request_to_dynamo_delete_badrequest(self): vax_type=message_body['vax_type'], table=self.mock_table, is_present=True - ) - + ) + def test_send_request_to_dynamo_delete_resource_not_found(self): """it should not delete the Immunization since no resource found for the record""" @@ -318,10 +318,10 @@ def test_send_request_to_dynamo_delete_resource_not_found(self): "operation_requested": "DELETE" } - self.mock_service.delete_immunization.return_value = update_result - + self.mock_service.delete_immunization.return_value = update_result + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.delete_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -344,10 +344,10 @@ def test_send_request_to_dynamo_delete_unhandled_error(self): "operation_requested": "DELETE" } - self.mock_service.delete_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") - + self.mock_service.delete_immunization.return_value = UnhandledResponseError("Non-200 response from dynamodb", "connection timeout") + result = self.controller.send_request_to_dynamo(message_body, self.mock_table, True) - + self.assertEqual(result, update_result) self.mock_service.delete_immunization.assert_called_once_with( immunization=message_body['fhir_json'], @@ -356,7 +356,7 @@ def test_send_request_to_dynamo_delete_unhandled_error(self): table=self.mock_table, is_present=True ) - + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/backend/tests/test_fhir_controller.py b/backend/tests/controller/test_fhir_controller.py similarity index 97% rename from backend/tests/test_fhir_controller.py rename to backend/tests/controller/test_fhir_controller.py index f25c7fdcf..caad9683b 100644 --- a/backend/tests/test_fhir_controller.py +++ b/backend/tests/controller/test_fhir_controller.py @@ -10,9 +10,11 @@ from unittest.mock import create_autospec, ANY, patch, Mock from urllib.parse import urlencode import urllib.parse -from fhir_controller import FhirController -from fhir_repository import ImmunizationRepository -from fhir_service import FhirService, UpdateOutcome + +from controller.aws_apig_response_utils import create_response +from controller.fhir_controller import FhirController +from repository.fhir_repository import ImmunizationRepository +from service.fhir_service import FhirService, UpdateOutcome from models.errors import ( ResourceNotFoundError, UnhandledResponseError, @@ -22,9 +24,9 @@ UnauthorizedVaxError, IdentifierDuplicationError, ) -from tests.utils.immunization_utils import create_covid_19_immunization +from testing_utils.immunization_utils import create_covid_19_immunization from parameter_parser import patient_identifier_system, process_search_params -from tests.utils.generic_utils import load_json_data +from testing_utils.generic_utils import load_json_data class TestFhirControllerBase(unittest.TestCase): """Base class for all tests to set up common fixtures""" @@ -52,7 +54,7 @@ def setUp(self): def test_create_response(self): """it should return application/fhir+json with correct status code""" body = {"message": "a body"} - res = self.controller.create_response(42, body) + res = create_response(42, body) headers = res["headers"] self.assertEqual(res["statusCode"], 42) @@ -65,7 +67,7 @@ def test_create_response(self): self.assertDictEqual(json.loads(res["body"]), body) def test_no_body_no_header(self): - res = self.controller.create_response(42) + res = create_response(42) self.assertEqual(res["statusCode"], 42) self.assertDictEqual(res["headers"], {}) self.assertTrue("body" not in res) @@ -671,7 +673,7 @@ def test_get_imms_by_id(self): """it should return Immunization resource if it exists""" # Given imms_id = "a-id" - self.service.get_immunization_by_id.return_value = Immunization.construct() + self.service.get_immunization_and_version_by_id.return_value = (Immunization.construct(), "1") lambda_event = { "headers": {"SupplierSystem": "test"}, "pathParameters": {"id": imms_id}, @@ -680,17 +682,37 @@ def test_get_imms_by_id(self): # When response = self.controller.get_immunization_by_id(lambda_event) # Then - self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test") + self.service.get_immunization_and_version_by_id.assert_called_once_with(imms_id, "test") self.assertEqual(response["statusCode"], 200) body = json.loads(response["body"]) self.assertEqual(body["resourceType"], "Immunization") + self.assertEqual(response["headers"]["E-Tag"], "1") - def test_get_imms_by_id_unauthorised_vax_error(self): + def test_get_imms_by_id_returns_unauthorized_when_supplier_header_missing(self): """it should return Immunization resource if it exists""" # Given + imms_id = "foo-123" + lambda_event = { + "headers": {"missing": "required supplier header"}, + "pathParameters": {"id": imms_id}, + } + + # When + response = self.controller.get_immunization_by_id(lambda_event) + # Then + self.service.get_immunization_and_version_by_id.assert_not_called() + + self.assertEqual(response["statusCode"], 403) + body = json.loads(response["body"]) + self.assertEqual(body["resourceType"], "OperationOutcome") + self.assertEqual(body["issue"][0]["code"], "forbidden") + + def test_get_imms_by_id_unauthorised_vax_error(self): + """it should return a 403 error is the service layer throws an UnauthorizedVaxError""" + # Given imms_id = "a-id" - self.service.get_immunization_by_id.side_effect = UnauthorizedVaxError + self.service.get_immunization_and_version_by_id.side_effect = UnauthorizedVaxError lambda_event = { "headers": {"SupplierSystem": "test"}, "pathParameters": {"id": imms_id}, @@ -700,12 +722,18 @@ def test_get_imms_by_id_unauthorised_vax_error(self): response = self.controller.get_immunization_by_id(lambda_event) # Then self.assertEqual(response["statusCode"], 403) + body = json.loads(response["body"]) + self.assertEqual(body["resourceType"], "OperationOutcome") + self.assertEqual(body["issue"][0]["code"], "forbidden") def test_not_found(self): """it should return not-found OperationOutcome if it doesn't exist""" # Given imms_id = "a-non-existing-id" - self.service.get_immunization_by_id.return_value = None + self.service.get_immunization_and_version_by_id.side_effect = ResourceNotFoundError( + resource_type="Immunization", + resource_id=imms_id + ) lambda_event = { "headers": {"SupplierSystem": "test"}, "pathParameters": {"id": imms_id}, @@ -715,7 +743,7 @@ def test_not_found(self): response = self.controller.get_immunization_by_id(lambda_event) # Then - self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test") + self.service.get_immunization_and_version_by_id.assert_called_once_with(imms_id, "test") self.assertEqual(response["statusCode"], 404) body = json.loads(response["body"]) @@ -728,7 +756,7 @@ def test_validate_imms_id(self): response = self.controller.get_immunization_by_id(invalid_id) - self.assertEqual(self.service.get_immunization_by_id.call_count, 0) + self.assertEqual(self.service.get_immunization_and_version_by_id.call_count, 0) self.assertEqual(response["statusCode"], 400) outcome = json.loads(response["body"]) self.assertEqual(outcome["resourceType"], "OperationOutcome") @@ -799,7 +827,7 @@ def test_malformed_resource(self): } response = self.controller.create_immunization(aws_event) - self.assertEqual(self.service.get_immunization_by_id.call_count, 0) + self.assertEqual(self.service.get_immunization_and_version_by_id.call_count, 0) self.assertEqual(response["statusCode"], 400) outcome = json.loads(response["body"]) self.assertEqual(outcome["resourceType"], "OperationOutcome") @@ -1072,28 +1100,6 @@ def test_update_deletedat_immunization_without_version(self): self.assertEqual(response["statusCode"], 200) self.assertEqual(response["headers"]["E-Tag"], 2) - def test_update_record_exists(self): - """it should return not-found OperationOutcome if ID doesn't exist""" - # Given - - imms_id = "a-non-existing-id" - self.service.get_immunization_by_id.return_value = None - lambda_event = { - "headers": {"E-Tag": 1, "SupplierSystem": "Test"}, - "pathParameters": {"id": imms_id}, - } - - # When - response = self.controller.get_immunization_by_id(lambda_event) - - # Then - self.service.get_immunization_by_id.assert_called_once_with(imms_id, "Test") - - self.assertEqual(response["statusCode"], 404) - body = json.loads(response["body"]) - self.assertEqual(body["resourceType"], "OperationOutcome") - self.assertEqual(body["issue"][0]["code"], "not-found") - def test_validation_error(self): """it should return 400 if Immunization is invalid""" # Given @@ -1327,7 +1333,7 @@ def test_validate_imms_id(self): response = self.controller.delete_immunization(invalid_id) - self.assertEqual(self.service.get_immunization_by_id.call_count, 0) + self.assertEqual(self.service.get_immunization_and_version_by_id.call_count, 0) self.assertEqual(response["statusCode"], 400) outcome = json.loads(response["body"]) self.assertEqual(outcome["resourceType"], "OperationOutcome") @@ -1660,7 +1666,7 @@ def test_post_search_immunizations_for_unauthorized_vaccine_type_search_403(self body = json.loads(response["body"]) self.assertEqual(body["resourceType"], "OperationOutcome") - @patch("fhir_controller.process_search_params", wraps=process_search_params) + @patch("controller.fhir_controller.process_search_params", wraps=process_search_params) def test_uses_parameter_parser(self, process_search_params: Mock): self.mock_redis_client.hkeys.return_value = self.MOCK_REDIS_V2D_HKEYS lambda_event = { @@ -1679,7 +1685,7 @@ def test_uses_parameter_parser(self, process_search_params: Mock): } ) - @patch("fhir_controller.process_search_params") + @patch("controller.fhir_controller.process_search_params") def test_search_immunizations_returns_400_on_ParameterException_from_parameter_parser( self, process_search_params: Mock ): diff --git a/backend/tests/models/__init__.py b/backend/tests/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/models/utils/__init__.py b/backend/tests/models/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/utils/test_generic_utils.py b/backend/tests/models/utils/test_generic_utils.py similarity index 98% rename from backend/tests/utils/test_generic_utils.py rename to backend/tests/models/utils/test_generic_utils.py index f7014fc79..852bc8598 100644 --- a/backend/tests/utils/test_generic_utils.py +++ b/backend/tests/models/utils/test_generic_utils.py @@ -2,7 +2,7 @@ import unittest from src.models.utils.generic_utils import form_json -from tests.utils.generic_utils import load_json_data, format_date_types +from testing_utils.generic_utils import load_json_data, format_date_types import unittest from datetime import datetime, date diff --git a/backend/tests/repository/__init__.py b/backend/tests/repository/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/test_fhir_batch_repository.py b/backend/tests/repository/test_fhir_batch_repository.py similarity index 94% rename from backend/tests/test_fhir_batch_repository.py rename to backend/tests/repository/test_fhir_batch_repository.py index 16fe19a56..e3a64c2a9 100644 --- a/backend/tests/test_fhir_batch_repository.py +++ b/backend/tests/repository/test_fhir_batch_repository.py @@ -7,8 +7,8 @@ from moto import mock_aws from uuid import uuid4 from models.errors import IdentifierDuplicationError, ResourceNotFoundError, UnhandledResponseError, ResourceFoundError -from fhir_batch_repository import ImmunizationBatchRepository, create_table -from tests.utils.immunization_utils import create_covid_19_immunization_dict +from repository.fhir_batch_repository import ImmunizationBatchRepository, create_table +from testing_utils.immunization_utils import create_covid_19_immunization_dict imms_id = str(uuid4()) @@ -18,7 +18,7 @@ def _make_immunization_pk(_id): @mock_aws class TestImmunizationBatchRepository(unittest.TestCase): - + def setUp(self): os.environ["DYNAMODB_TABLE_NAME"] = "test-immunization-table" self.dynamodb = boto3.resource("dynamodb", region_name="eu-west-2") @@ -37,8 +37,8 @@ def setUp(self): def tearDown(self): patch.stopall() -class TestCreateImmunization(TestImmunizationBatchRepository): - +class TestCreateImmunization(TestImmunizationBatchRepository): + def modify_immunization(self, remove_nhs): """Modify the immunization object by removing NHS number if required""" if remove_nhs: @@ -78,8 +78,8 @@ def test_create_immunization_with_nhs_number(self): def test_create_immunization_without_nhs_number(self): """Test creating Immunization without NHS number.""" - - self.create_immunization_test_logic(is_present=False, remove_nhs=True) + + self.create_immunization_test_logic(is_present=False, remove_nhs=True) def test_create_immunization_duplicate(self): @@ -92,7 +92,7 @@ def test_create_immunization_duplicate(self): }) with self.assertRaises(IdentifierDuplicationError): self.repository.create_immunization(self.immunization, "supplier", "vax-type", self.table, False) - self.table.put_item.assert_not_called() + self.table.put_item.assert_not_called() def test_create_should_catch_dynamo_error(self): """it should throw UnhandledResponse when the response from dynamodb can't be handled""" @@ -102,7 +102,7 @@ def test_create_should_catch_dynamo_error(self): self.table.put_item = MagicMock(return_value=response) with self.assertRaises(UnhandledResponseError) as e: self.repository.create_immunization(self.immunization, "supplier", "vax-type", self.table, False) - self.assertDictEqual(e.exception.response, response) + self.assertDictEqual(e.exception.response, response) def test_create_immunization_unhandled_error(self): @@ -119,10 +119,10 @@ def test_create_immunization_conditionalcheckfailedexception_error(self): with unittest.mock.patch.object(self.table, 'put_item', side_effect=botocore.exceptions.ClientError({"Error": {"Code": "ConditionalCheckFailedException"}}, "PutItem")): with self.assertRaises(ResourceFoundError): - self.repository.create_immunization(self.immunization, "supplier", "vax-type", self.table, False) - + self.repository.create_immunization(self.immunization, "supplier", "vax-type", self.table, False) + -class TestUpdateImmunization(TestImmunizationBatchRepository): +class TestUpdateImmunization(TestImmunizationBatchRepository): def test_update_immunization(self): """it should update Immunization record""" @@ -165,7 +165,7 @@ def test_update_immunization(self): }, "expected_extra_values": {} } - ] + ] for is_present in [True, False]: for case in test_cases: with self.subTest(is_present=is_present, case=case): @@ -183,7 +183,7 @@ def test_update_immunization(self): ":supplier_system": "supplier" } expected_values.update(case["expected_extra_values"]) - + self.table.update_item.assert_called_with( Key={"PK": _make_immunization_pk(imms_id)}, UpdateExpression=ANY, @@ -193,7 +193,7 @@ def test_update_immunization(self): ConditionExpression=ANY, ) self.assertEqual(response, f'Immunization#{self.immunization["id"]}') - + def test_update_immunization_not_found(self): """it should not update Immunization since the imms id not found""" @@ -218,7 +218,7 @@ def test_update_should_catch_dynamo_error(self): ) with self.assertRaises(UnhandledResponseError) as e: self.repository.update_immunization(self.immunization, "supplier", "vax-type", self.table, False) - self.assertDictEqual(e.exception.response, response) + self.assertDictEqual(e.exception.response, response) def test_update_immunization_unhandled_error(self): """it should throw UnhandledResponse when the response from dynamodb can't be handled""" @@ -235,7 +235,7 @@ def test_update_immunization_unhandled_error(self): }] } ) - self.repository.update_immunization(self.immunization, "supplier", "vax-type", self.table, False) + self.repository.update_immunization(self.immunization, "supplier", "vax-type", self.table, False) self.assertDictEqual(e.exception.response, response) def test_update_immunization_conditionalcheckfailedexception_error(self): @@ -252,9 +252,9 @@ def test_update_immunization_conditionalcheckfailedexception_error(self): }] } ) - self.repository.update_immunization(self.immunization, "supplier", "vax-type", self.table, False) - -class TestDeleteImmunization(TestImmunizationBatchRepository): + self.repository.update_immunization(self.immunization, "supplier", "vax-type", self.table, False) + +class TestDeleteImmunization(TestImmunizationBatchRepository): def test_delete_immunization(self): """it should delete Immunization record""" @@ -276,7 +276,7 @@ def test_delete_immunization(self): ReturnValues=ANY, ConditionExpression=ANY, ) - self.assertEqual(response, f'Immunization#{self.immunization ["id"]}') + self.assertEqual(response, f'Immunization#{self.immunization ["id"]}') def test_delete_immunization_not_found(self): """it should not delete Immunization since the imms id not found""" @@ -302,7 +302,7 @@ def test_delete_should_catch_dynamo_error(self): ) with self.assertRaises(UnhandledResponseError) as e: self.repository.delete_immunization(self.immunization, "supplier", "vax-type", self.table, False) - self.assertDictEqual(e.exception.response, response) + self.assertDictEqual(e.exception.response, response) def test_delete_immunization_unhandled_error(self): """it should throw UnhandledResponse when the response from dynamodb can't be handled""" @@ -319,7 +319,7 @@ def test_delete_immunization_unhandled_error(self): }] } ) - self.repository.delete_immunization(self.immunization, "supplier", "vax-type", self.table, False) + self.repository.delete_immunization(self.immunization, "supplier", "vax-type", self.table, False) self.assertDictEqual(e.exception.response, response) def test_delete_immunization_conditionalcheckfailedexception_error(self): diff --git a/backend/tests/test_fhir_repository.py b/backend/tests/repository/test_fhir_repository.py similarity index 97% rename from backend/tests/test_fhir_repository.py rename to backend/tests/repository/test_fhir_repository.py index fe9a1e1c9..46e49e419 100644 --- a/backend/tests/test_fhir_repository.py +++ b/backend/tests/repository/test_fhir_repository.py @@ -6,16 +6,15 @@ import botocore.exceptions from boto3.dynamodb.conditions import Attr, Key -from fhir_repository import ImmunizationRepository +from repository.fhir_repository import ImmunizationRepository from models.utils.validation_utils import get_vaccine_type from models.errors import ( ResourceNotFoundError, UnhandledResponseError, - IdentifierDuplicationError, - UnauthorizedVaxError + IdentifierDuplicationError ) -from tests.utils.generic_utils import update_target_disease_code -from tests.utils.immunization_utils import create_covid_19_immunization_dict +from testing_utils.generic_utils import update_target_disease_code +from testing_utils.immunization_utils import create_covid_19_immunization_dict def _make_immunization_pk(_id): return f"Immunization#{_id}" @@ -101,22 +100,22 @@ def tearDown(self): def test_get_immunization_by_id(self): """it should find an Immunization by id""" imms_id = "an-id" - resource = dict() - resource["Resource"] = {"foo": "bar"} - resource["Version"] = 1 + expected_resource = {"foo": "bar"} + expected_version = "1" self.table.get_item = MagicMock( return_value={ "Item": { - "Resource": json.dumps({"foo": "bar"}), - "Version": 1, + "Resource": json.dumps(expected_resource), + "Version": expected_version, "PatientSK": "COVID19#2516525251", } } ) - imms = self.repository.get_immunization_by_id(imms_id) + immunisation, version = self.repository.get_immunization_and_version_by_id(imms_id) # Validate the results - self.assertDictEqual(resource, imms) + self.assertDictEqual(expected_resource, immunisation) + self.assertEqual(version, expected_version) self.table.get_item.assert_called_once_with(Key={"PK": _make_immunization_pk(imms_id)}) def test_immunization_not_found(self): @@ -124,8 +123,9 @@ def test_immunization_not_found(self): imms_id = "non-existent-id" self.table.get_item = MagicMock(return_value={}) - imms = self.repository.get_immunization_by_id(imms_id) + imms, version = self.repository.get_immunization_and_version_by_id(imms_id) self.assertIsNone(imms) + self.assertIsNone(version) def _make_a_patient(nhs_number="1234567890") -> dict: @@ -466,8 +466,9 @@ def test_get_deleted_immunization(self): imms_id = "a-deleted-id" self.table.get_item = MagicMock(return_value={"Item": {"Resource": "{}", "DeletedAt": time.time()}}) - imms = self.repository.get_immunization_by_id(imms_id) + imms, version = self.repository.get_immunization_and_version_by_id(imms_id) self.assertIsNone(imms) + self.assertIsNone(version) def test_delete_immunization(self): """it should logical delete Immunization by setting DeletedAt attribute""" diff --git a/backend/tests/service/__init__.py b/backend/tests/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/test_fhir_batch_service.py b/backend/tests/service/test_fhir_batch_service.py similarity index 82% rename from backend/tests/test_fhir_batch_service.py rename to backend/tests/service/test_fhir_batch_service.py index 94a2cd14b..28303ce54 100644 --- a/backend/tests/test_fhir_batch_service.py +++ b/backend/tests/service/test_fhir_batch_service.py @@ -2,11 +2,11 @@ import uuid from copy import deepcopy from unittest.mock import Mock, create_autospec, patch -from tests.utils.immunization_utils import create_covid_19_immunization_dict_no_id +from testing_utils.immunization_utils import create_covid_19_immunization_dict_no_id from models.errors import CustomValidationError from models.fhir_immunization import ImmunizationValidator -from fhir_batch_repository import ImmunizationBatchRepository -from fhir_batch_service import ImmunizationBatchService +from repository.fhir_batch_repository import ImmunizationBatchRepository +from service.fhir_batch_service import ImmunizationBatchService class TestFhirBatchServiceBase(unittest.TestCase): @@ -51,10 +51,10 @@ def test_create_immunization_valid(self): imms_id = str(uuid.uuid4()) self.mock_repo.create_immunization.return_value = imms_id result = self.service.create_immunization( - immunization=create_covid_19_immunization_dict_no_id(), - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=create_covid_19_immunization_dict_no_id(), + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertEqual(result, imms_id) @@ -67,18 +67,18 @@ def test_create_immunization_pre_validation_error(self): expected_msg = "Validation errors: status must be one of the following: completed" with self.assertRaises(CustomValidationError) as error: self.pre_validate_fhir_service.create_immunization( - immunization=imms, - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=imms, + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertTrue(expected_msg in error.exception.message) - self.mock_repo.create_immunization.assert_not_called() + self.mock_repo.create_immunization.assert_not_called() def test_create_immunization_post_validation_error(self): """it should return error since it got failed in initial validation""" - + valid_imms = create_covid_19_immunization_dict_no_id() bad_target_disease_imms = deepcopy(valid_imms) bad_target_disease_imms["protocolApplied"][0]["targetDisease"][0]["coding"][0]["code"] = "bad-code" @@ -86,14 +86,14 @@ def test_create_immunization_post_validation_error(self): self.mock_redis_client.hget.return_value = None # Reset mock for invalid cases with self.assertRaises(CustomValidationError) as error: self.pre_validate_fhir_service.create_immunization( - immunization=bad_target_disease_imms, - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=bad_target_disease_imms, + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertTrue(expected_msg in error.exception.message) - self.mock_repo.create_immunization.assert_not_called() + self.mock_repo.create_immunization.assert_not_called() class TestUpdateImmunizationBatchService(TestFhirBatchServiceBase): @@ -122,10 +122,10 @@ def test_update_immunization_valid(self): imms_id = str(uuid.uuid4()) self.mock_repo.update_immunization.return_value = imms_id result = self.service.update_immunization( - immunization=create_covid_19_immunization_dict_no_id(), - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=create_covid_19_immunization_dict_no_id(), + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertEqual(result, imms_id) @@ -138,14 +138,14 @@ def test_update_immunization_pre_validation_error(self): expected_msg = "Validation errors: status must be one of the following: completed" with self.assertRaises(CustomValidationError) as error: self.pre_validate_fhir_service.update_immunization( - immunization=imms, - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=imms, + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertTrue(expected_msg in error.exception.message) - self.mock_repo.update_immunization.assert_not_called() + self.mock_repo.update_immunization.assert_not_called() def test_update_immunization_post_validation_error(self): """it should return error since it got failed in initial validation""" @@ -158,10 +158,10 @@ def test_update_immunization_post_validation_error(self): expected_msg = "protocolApplied[0].targetDisease[*].coding[?(@.system=='http://snomed.info/sct')].code - ['bad-code'] is not a valid combination of disease codes for this service" with self.assertRaises(CustomValidationError) as error: self.pre_validate_fhir_service.update_immunization( - immunization=bad_target_disease_imms, - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=bad_target_disease_imms, + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) self.assertTrue(expected_msg in error.exception.message) @@ -185,15 +185,15 @@ def test_delete_immunization_valid(self): imms_id = str(uuid.uuid4()) self.mock_repo.delete_immunization.return_value = imms_id result = self.service.delete_immunization( - immunization=create_covid_19_immunization_dict_no_id(), - supplier_system="test_supplier", - vax_type="test_vax", - table=self.mock_table, + immunization=create_covid_19_immunization_dict_no_id(), + supplier_system="test_supplier", + vax_type="test_vax", + table=self.mock_table, is_present=True ) - self.assertEqual(result, imms_id) + self.assertEqual(result, imms_id) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/backend/tests/test_fhir_service.py b/backend/tests/service/test_fhir_service.py similarity index 95% rename from backend/tests/test_fhir_service.py rename to backend/tests/service/test_fhir_service.py index ecf39a436..255f8c827 100644 --- a/backend/tests/test_fhir_service.py +++ b/backend/tests/service/test_fhir_service.py @@ -3,7 +3,7 @@ import datetime import unittest import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from copy import deepcopy from unittest.mock import create_autospec, patch from decimal import Decimal @@ -13,20 +13,20 @@ from authorisation.api_operation_code import ApiOperationCode from authorisation.authoriser import Authoriser -from fhir_repository import ImmunizationRepository -from fhir_service import FhirService, UpdateOutcome, get_service_url +from repository.fhir_repository import ImmunizationRepository +from service.fhir_service import FhirService, UpdateOutcome, get_service_url from models.errors import InvalidPatientId, CustomValidationError, UnauthorizedVaxError, ResourceNotFoundError from models.fhir_immunization import ImmunizationValidator from models.utils.generic_utils import get_contained_patient from pydantic import ValidationError from pydantic.error_wrappers import ErrorWrapper -from tests.utils.immunization_utils import ( +from testing_utils.immunization_utils import ( create_covid_19_immunization, create_covid_19_immunization_dict, create_covid_19_immunization_dict_no_id, VALID_NHS_NUMBER, ) -from tests.utils.generic_utils import load_json_data +from testing_utils.generic_utils import load_json_data from constants import NHS_NUMBER_USED_IN_SAMPLE_DATA class TestFhirServiceBase(unittest.TestCase): @@ -205,29 +205,30 @@ def test_get_immunization_by_id(self): imms_id = "an-id" self.mock_redis_client.hget.return_value = "COVID-19" self.authoriser.authorise.return_value = True - self.imms_repo.get_immunization_by_id.return_value = {"Resource": create_covid_19_immunization(imms_id).dict()} + self.imms_repo.get_immunization_and_version_by_id.return_value = (create_covid_19_immunization(imms_id).dict(), "") # When - service_resp = self.fhir_service.get_immunization_by_id(imms_id, "Test Supplier") - act_imms = service_resp["Resource"] + immunisation, version = self.fhir_service.get_immunization_and_version_by_id(imms_id, "Test Supplier") # Then self.authoriser.authorise.assert_called_once_with("Test Supplier", ApiOperationCode.READ, {"COVID-19"}) - self.imms_repo.get_immunization_by_id.assert_called_once_with(imms_id) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(imms_id) - self.assertEqual(act_imms.id, imms_id) + self.assertEqual(immunisation.id, imms_id) + self.assertEqual(version, "") def test_immunization_not_found(self): """it should return None if Immunization doesn't exist""" - imms_id = "none-existent-id" - self.imms_repo.get_immunization_by_id.return_value = None + imms_id = "non-existent-id" + self.imms_repo.get_immunization_and_version_by_id.return_value = None, None # When - act_imms = self.fhir_service.get_immunization_by_id(imms_id, "Test Supplier") + with self.assertRaises(ResourceNotFoundError) as error: + self.fhir_service.get_immunization_and_version_by_id(imms_id, "Test Supplier") # Then - self.imms_repo.get_immunization_by_id.assert_called_once_with(imms_id) - self.assertEqual(act_imms, None) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(imms_id) + self.assertEqual("Immunization resource does not exist. ID: non-existent-id", str(error.exception)) def test_get_immunization_by_id_patient_not_restricted(self): """ @@ -239,18 +240,19 @@ def test_get_immunization_by_id_patient_not_restricted(self): immunization_data = load_json_data("completed_covid19_immunization_event.json") self.mock_redis_client.hget.return_value = "COVID-19" self.authoriser.authorise.return_value = True - self.imms_repo.get_immunization_by_id.return_value = {"Resource": immunization_data} + self.imms_repo.get_immunization_and_version_by_id.return_value = (immunization_data, "2") expected_imms = load_json_data("completed_covid19_immunization_event_for_read.json") expected_output = Immunization.parse_obj(expected_imms) # When - actual_output = self.fhir_service.get_immunization_by_id(imms_id, "Test Supplier") + actual_output, version = self.fhir_service.get_immunization_and_version_by_id(imms_id, "Test Supplier") # Then self.authoriser.authorise.assert_called_once_with("Test Supplier", ApiOperationCode.READ, {"COVID-19"}) - self.imms_repo.get_immunization_by_id.assert_called_once_with(imms_id) - self.assertEqual(actual_output["Resource"], expected_output) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(imms_id) + self.assertEqual(actual_output, expected_output) + self.assertEqual(version, "2") def test_pre_validation_failed(self): """it should throw exception if Immunization is not valid""" @@ -282,15 +284,15 @@ def test_unauthorised_error_raised_when_user_lacks_permissions(self): imms_id = "an-id" self.mock_redis_client.hget.return_value = "COVID-19" self.authoriser.authorise.return_value = False - self.imms_repo.get_immunization_by_id.return_value = {"Resource": create_covid_19_immunization(imms_id).dict()} + self.imms_repo.get_immunization_and_version_by_id.return_value = (create_covid_19_immunization(imms_id).dict(), 1) with self.assertRaises(UnauthorizedVaxError): # When - self.fhir_service.get_immunization_by_id(imms_id, "Test Supplier") + self.fhir_service.get_immunization_and_version_by_id(imms_id, "Test Supplier") # Then self.authoriser.authorise.assert_called_once_with("Test Supplier", ApiOperationCode.READ, {"COVID-19"}) - self.imms_repo.get_immunization_by_id.assert_called_once_with(imms_id) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(imms_id) def test_post_validation_failed_get_invalid_target_disease(self): @@ -727,14 +729,14 @@ def test_delete_immunization(self): imms = json.loads(create_covid_19_immunization(self.TEST_IMMUNISATION_ID).json()) self.mock_redis_client.hget.return_value = "COVID19" self.authoriser.authorise.return_value = True - self.imms_repo.get_immunization_by_id.return_value = {"Resource": imms} + self.imms_repo.get_immunization_and_version_by_id.return_value = (imms, "1") self.imms_repo.delete_immunization.return_value = imms # When act_imms = self.fhir_service.delete_immunization(self.TEST_IMMUNISATION_ID, "Test") # Then - self.imms_repo.get_immunization_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) self.imms_repo.delete_immunization.assert_called_once_with(self.TEST_IMMUNISATION_ID, "Test") self.authoriser.authorise.assert_called_once_with("Test", ApiOperationCode.DELETE, {"COVID19"}) self.assertIsInstance(act_imms, Immunization) @@ -742,28 +744,28 @@ def test_delete_immunization(self): def test_delete_immunization_throws_not_found_exception_if_does_not_exist(self): """it should raise a ResourceNotFound exception if the immunisation does not exist""" - self.imms_repo.get_immunization_by_id.return_value = None + self.imms_repo.get_immunization_and_version_by_id.return_value = (None, None) # When with self.assertRaises(ResourceNotFoundError): self.fhir_service.delete_immunization(self.TEST_IMMUNISATION_ID, "Test") # Then - self.imms_repo.get_immunization_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) self.imms_repo.delete_immunization.assert_not_called() def test_delete_immunization_throws_authorisation_exception_if_does_not_have_required_permissions(self): imms = json.loads(create_covid_19_immunization(self.TEST_IMMUNISATION_ID).json()) self.mock_redis_client.hget.return_value = "FLU" self.authoriser.authorise.return_value = False - self.imms_repo.get_immunization_by_id.return_value = {"Resource": imms} + self.imms_repo.get_immunization_and_version_by_id.return_value = (imms, "1") # When with self.assertRaises(UnauthorizedVaxError): self.fhir_service.delete_immunization(self.TEST_IMMUNISATION_ID, "Test") # Then - self.imms_repo.get_immunization_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) + self.imms_repo.get_immunization_and_version_by_id.assert_called_once_with(self.TEST_IMMUNISATION_ID) self.imms_repo.delete_immunization.assert_not_called() self.authoriser.authorise.assert_called_once_with("Test", ApiOperationCode.DELETE, {"FLU"}) diff --git a/backend/tests/test_create_imms.py b/backend/tests/test_create_imms.py index 232889d01..3bf2bc585 100644 --- a/backend/tests/test_create_imms.py +++ b/backend/tests/test_create_imms.py @@ -3,7 +3,7 @@ from unittest.mock import create_autospec, patch from create_imms_handler import create_immunization -from fhir_controller import FhirController +from controller.fhir_controller import FhirController from models.errors import Severity, Code, create_operation_outcome from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE @@ -15,7 +15,7 @@ def setUp(self): self.mock_logger_info = self.logger_info_patcher.start() self.logger_exception_patcher = patch("logging.Logger.exception") self.mock_logger_exception = self.logger_exception_patcher.start() - + def tearDown(self): patch.stopall() diff --git a/backend/tests/test_delete_imms.py b/backend/tests/test_delete_imms.py index 4bc55eca5..513549f98 100644 --- a/backend/tests/test_delete_imms.py +++ b/backend/tests/test_delete_imms.py @@ -3,7 +3,7 @@ from unittest.mock import create_autospec, patch from delete_imms_handler import delete_immunization -from fhir_controller import FhirController +from controller.fhir_controller import FhirController from models.errors import Severity, Code, create_operation_outcome from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE diff --git a/backend/tests/test_filter.py b/backend/tests/test_filter.py index aac2082ac..e18a6a4cf 100644 --- a/backend/tests/test_filter.py +++ b/backend/tests/test_filter.py @@ -14,7 +14,7 @@ replace_address_postal_codes, replace_organization_values, ) -from tests.utils.generic_utils import load_json_data +from testing_utils.generic_utils import load_json_data class TestFilter(unittest.TestCase): @@ -153,4 +153,4 @@ def test_filter_search(self): ) expected_output["patient"]["reference"] = patient_full_url - self.assertEqual(Filter.search(unfiltered_imms, patient_full_url), expected_output) \ No newline at end of file + self.assertEqual(Filter.search(unfiltered_imms, patient_full_url), expected_output) diff --git a/backend/tests/test_forwarding_batch_lambda.py b/backend/tests/test_forwarding_batch_lambda.py index 8ffb58c57..b83e28f9e 100644 --- a/backend/tests/test_forwarding_batch_lambda.py +++ b/backend/tests/test_forwarding_batch_lambda.py @@ -2,7 +2,7 @@ import os from typing import Optional from unittest import TestCase -from unittest.mock import patch, MagicMock, call, ANY +from unittest.mock import patch, MagicMock, ANY from boto3 import resource as boto3_resource from moto import mock_aws from models.errors import ( @@ -17,11 +17,10 @@ import copy import json -from utils.test_utils_for_batch import ForwarderValues, MockFhirImmsResources +from testing_utils.test_utils_for_batch import ForwarderValues, MockFhirImmsResources with patch.dict("os.environ", ForwarderValues.MOCK_ENVIRONMENT_DICT): - from forwarding_batch_lambda import forward_lambda_handler, create_diagnostics_dictionary, forward_request_to_dynamo, \ - QUEUE_URL + from forwarding_batch_lambda import forward_lambda_handler, create_diagnostics_dictionary @mock_aws diff --git a/backend/tests/test_get_imms.py b/backend/tests/test_get_imms.py index 88607b6fa..e91ba3ac0 100644 --- a/backend/tests/test_get_imms.py +++ b/backend/tests/test_get_imms.py @@ -1,11 +1,8 @@ -import json import unittest from unittest.mock import create_autospec -from fhir_controller import FhirController +from controller.fhir_controller import FhirController from get_imms_handler import get_immunization_by_id -from models.errors import Severity, Code, create_operation_outcome -from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE @@ -26,26 +23,3 @@ def test_get_immunization_by_id(self): # Then self.controller.get_immunization_by_id.assert_called_once_with(lambda_event) self.assertDictEqual(exp_res, act_res) - - def test_get_handle_exception(self): - """unhandled exceptions should result in 500""" - lambda_event = {"headers": {"id": "an-id"}, "pathParameters": {"id": "an-id"}} - error_msg = "an unhandled error" - self.controller.get_immunization_by_id.side_effect = Exception(error_msg) - - exp_error = create_operation_outcome( - resource_id=None, - severity=Severity.error, - code=Code.server_error, - diagnostics=GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE, - ) - - # When - act_res = get_immunization_by_id(lambda_event, self.controller) - - # Then - act_body = json.loads(act_res["body"]) - act_body["id"] = None - - self.assertDictEqual(act_body, exp_error) - self.assertEqual(act_res["statusCode"], 500) diff --git a/backend/tests/test_immunization_post_validator.py b/backend/tests/test_immunization_post_validator.py index 29d0f1cfc..2bd98d47e 100644 --- a/backend/tests/test_immunization_post_validator.py +++ b/backend/tests/test_immunization_post_validator.py @@ -9,14 +9,14 @@ from models.fhir_immunization import ImmunizationValidator -from tests.utils.generic_utils import ( +from testing_utils.generic_utils import ( # these have an underscore to avoid pytest collecting them as tests test_invalid_values_rejected as _test_invalid_values_rejected, load_json_data, ) -from tests.utils.mandation_test_utils import MandationTests -from tests.utils.values_for_tests import NameInstances -from tests.utils.generic_utils import update_contained_resource_field +from testing_utils.mandation_test_utils import MandationTests +from testing_utils.values_for_tests import NameInstances +from testing_utils.generic_utils import update_contained_resource_field class TestImmunizationModelPostValidationRules(unittest.TestCase): """Test immunization post validation rules on the FHIR model""" diff --git a/backend/tests/test_immunization_pre_validator.py b/backend/tests/test_immunization_pre_validator.py index ba5836148..adfe3ade2 100644 --- a/backend/tests/test_immunization_pre_validator.py +++ b/backend/tests/test_immunization_pre_validator.py @@ -7,10 +7,9 @@ from jsonpath_ng.ext import parse -from clients import redis_client from models.fhir_immunization import ImmunizationValidator from models.utils.generic_utils import get_generic_extension_value -from utils.generic_utils import ( +from testing_utils.generic_utils import ( # these have an underscore to avoid pytest collecting them as tests test_valid_values_accepted as _test_valid_values_accepted, test_invalid_values_rejected as _test_invalid_values_rejected, @@ -22,9 +21,8 @@ practitioner_name_given_field_location, practitioner_name_family_field_location, ) -from utils.pre_validation_test_utils import ValidatorModelTests -from utils.values_for_tests import ValidValues, InvalidValues -from models.constants import Constants +from testing_utils.pre_validation_test_utils import ValidatorModelTests +from testing_utils.values_for_tests import ValidValues, InvalidValues from models.fhir_immunization_pre_validators import PreValidators class TestImmunizationModelPreValidationRules(unittest.TestCase): @@ -37,9 +35,7 @@ def setUp(self): self.validator = ImmunizationValidator(add_post_validators=False) self.redis_patcher = patch("models.utils.validation_utils.redis_client") self.mock_redis_client = self.redis_patcher.start() - - - + def tearDown(self): patch.stopall() @@ -1173,13 +1169,13 @@ def test_pre_validate_dose_quantity_system(self): system_location = "doseQuantity.system" ValidatorModelTests.test_string_value(self, system_location, valid_strings_to_test=["http://unitsofmeasure.org"]) - + def test_pre_validate_dose_quantity_code(self): """Test pre_validate_dose_quantity_code accepts valid values and rejects invalid values""" code_location = "doseQuantity.code" ValidatorModelTests.test_string_value(self, code_location, valid_strings_to_test=["ABC123"]) - + def test_pre_validate_dose_quantity_system_and_code(self): """Test pre_validate_dose_quantity_system_and_code accepts valid values and rejects invalid values""" diff --git a/backend/tests/test_parameter_parser.py b/backend/tests/test_parameter_parser.py index d2aa3555f..98660acea 100644 --- a/backend/tests/test_parameter_parser.py +++ b/backend/tests/test_parameter_parser.py @@ -3,7 +3,7 @@ import datetime from unittest.mock import create_autospec, patch -from fhir_service import FhirService +from service.fhir_service import FhirService from models.errors import ParameterException from parameter_parser import ( date_from_key, diff --git a/backend/tests/test_search_imms.py b/backend/tests/test_search_imms.py index c29e6bafe..7e432758a 100644 --- a/backend/tests/test_search_imms.py +++ b/backend/tests/test_search_imms.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import create_autospec, patch -from fhir_controller import FhirController +from controller.fhir_controller import FhirController from models.errors import Severity, Code, create_operation_outcome from search_imms_handler import search_imms from pathlib import Path diff --git a/backend/tests/test_update_imms.py b/backend/tests/test_update_imms.py index 8befd591f..5c1afef88 100644 --- a/backend/tests/test_update_imms.py +++ b/backend/tests/test_update_imms.py @@ -1,8 +1,7 @@ -import json import unittest from unittest.mock import create_autospec, patch -from fhir_controller import FhirController +from controller.fhir_controller import FhirController from models.errors import Severity, Code, create_operation_outcome from update_imms_handler import update_imms from constants import GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE @@ -36,7 +35,7 @@ def test_update_immunization(self): self.controller.update_immunization.assert_called_once_with(lambda_event) self.assertDictEqual(exp_res, act_res) - @patch("update_imms_handler.FhirController.create_response") + @patch("update_imms_handler.create_response") def test_update_imms_exception(self, mock_create_response): """unhandled exceptions should result in 500""" lambda_event = {"pathParameters": {"id": "an-id"}} @@ -68,7 +67,7 @@ def test_update_imms_exception(self, mock_create_response): self.assertEqual(diagnostics, GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE) self.assertEqual(act_res, mock_response) - + def test_update_imms_with_duplicated_identifier_returns_error(self): """Should return an IdentifierDuplication error""" diff --git a/backend/tests/test_utils.py b/backend/tests/test_utils.py index 82a920d2d..72f9255a2 100644 --- a/backend/tests/test_utils.py +++ b/backend/tests/test_utils.py @@ -1,12 +1,11 @@ """Tests for generic utils""" import unittest -import json -from unittest.mock import patch, MagicMock +from unittest.mock import patch from copy import deepcopy from models.utils.validation_utils import convert_disease_codes_to_vaccine_type, get_vaccine_type -from utils.generic_utils import load_json_data, update_target_disease_code +from testing_utils.generic_utils import load_json_data, update_target_disease_code class TestGenericUtils(unittest.TestCase): diff --git a/backend/tests/test_validation_utils.py b/backend/tests/test_validation_utils.py index 490361c17..6ec795ac3 100644 --- a/backend/tests/test_validation_utils.py +++ b/backend/tests/test_validation_utils.py @@ -1,9 +1,7 @@ import unittest from copy import deepcopy -from base_utils.base_utils import obtain_field_location from jsonpath_ng.ext import parse -from models.field_locations import FieldLocations from models.obtain_field_value import ObtainFieldValue from models.utils.generic_utils import ( get_current_name_instance, @@ -13,10 +11,10 @@ ) from models.fhir_immunization import ImmunizationValidator -from utils.generic_utils import ( +from testing_utils.generic_utils import ( load_json_data, ) -from utils.values_for_tests import ValidValues, InvalidValues, NameInstances +from testing_utils.values_for_tests import ValidValues, InvalidValues, NameInstances class TestValidatorUtils(unittest.TestCase): diff --git a/backend/tests/testing_utils/__init__.py b/backend/tests/testing_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/utils/generic_utils.py b/backend/tests/testing_utils/generic_utils.py similarity index 100% rename from backend/tests/utils/generic_utils.py rename to backend/tests/testing_utils/generic_utils.py diff --git a/backend/tests/utils/immunization_utils.py b/backend/tests/testing_utils/immunization_utils.py similarity index 92% rename from backend/tests/utils/immunization_utils.py rename to backend/tests/testing_utils/immunization_utils.py index 488e9a917..1704b6165 100644 --- a/backend/tests/utils/immunization_utils.py +++ b/backend/tests/testing_utils/immunization_utils.py @@ -2,8 +2,8 @@ from fhir.resources.R4B.immunization import Immunization -from tests.utils.values_for_tests import ValidValues -from tests.utils.generic_utils import load_json_data +from testing_utils.values_for_tests import ValidValues +from testing_utils.generic_utils import load_json_data VALID_NHS_NUMBER = ValidValues.nhs_number diff --git a/backend/tests/utils/mandation_test_utils.py b/backend/tests/testing_utils/mandation_test_utils.py similarity index 100% rename from backend/tests/utils/mandation_test_utils.py rename to backend/tests/testing_utils/mandation_test_utils.py diff --git a/backend/tests/utils/pre_validation_test_utils.py b/backend/tests/testing_utils/pre_validation_test_utils.py similarity index 100% rename from backend/tests/utils/pre_validation_test_utils.py rename to backend/tests/testing_utils/pre_validation_test_utils.py diff --git a/backend/tests/utils/test_utils_for_batch.py b/backend/tests/testing_utils/test_utils_for_batch.py similarity index 100% rename from backend/tests/utils/test_utils_for_batch.py rename to backend/tests/testing_utils/test_utils_for_batch.py diff --git a/backend/tests/utils/values_for_tests.py b/backend/tests/testing_utils/values_for_tests.py similarity index 99% rename from backend/tests/utils/values_for_tests.py rename to backend/tests/testing_utils/values_for_tests.py index 0403aa334..7ddec81af 100644 --- a/backend/tests/utils/values_for_tests.py +++ b/backend/tests/testing_utils/values_for_tests.py @@ -397,5 +397,5 @@ class InvalidValues: {"use": "official", "given": ["Florence"]}, {"family": "Nightingale", "given": ""}, ] - + invalid_dose_quantity = {"value": 2, "unit": "ml", "code": "258773002"} diff --git a/backend/tests/utils/test_dict_utils.py b/backend/tests/utils/test_dict_utils.py new file mode 100644 index 000000000..718fac210 --- /dev/null +++ b/backend/tests/utils/test_dict_utils.py @@ -0,0 +1,44 @@ +import unittest + +from utils import dict_utils + + +class TestDictUtils(unittest.TestCase): + + def test_get_field_returns_none_if_value_is_not_dict(self): + """Test that the default None value is returned if the provided argument is not a dict""" + result = dict_utils.get_field(["test"], "test_key") + + self.assertIsNone(result) + + def test_get_field_returns_default_value_if_key_not_in_dict(self): + """Test that the default None value is returned if the given key is not present in the dict""" + test_dict = {"test": "foo"} + + result = dict_utils.get_field(test_dict, "not_present") + + self.assertIsNone(result) + + def test_get_field_returns_value_from_nested_dict(self): + """Test that a value is retrieved from a nested dictionary""" + test_dict = {"a": {"b": {"c": 42}}} + + result = dict_utils.get_field(test_dict, "a", "b", "c") + + self.assertEqual(result, 42) + + def test_get_field_returns_a_dictionary_from_nested_dict(self): + """Test that where the value to retrieve is a dictionary then this is also successful""" + test_dict = {"a": {"b": {"c": {"foo": {"bar": "test"}}}}} + + result = dict_utils.get_field(test_dict, "a", "b", "c") + + self.assertDictEqual(result, {"foo": {"bar": "test"}}) + + def test_get_field_returns_override_default_value_when_provided(self): + """Test that when a key is not found and the user provides an override default value then this is returned""" + test_dict = {"a": {"test": "testing"}} + + result = dict_utils.get_field(test_dict, "a", "does_not_exist", default="") + + self.assertEqual(result, "") diff --git a/lambdas/ack_backend/tests/test_splunk_logging.py b/lambdas/ack_backend/tests/test_splunk_logging.py index 40edbe906..e5ddfb695 100644 --- a/lambdas/ack_backend/tests/test_splunk_logging.py +++ b/lambdas/ack_backend/tests/test_splunk_logging.py @@ -6,7 +6,6 @@ from contextlib import ExitStack from moto import mock_s3 from boto3 import client as boto3_client -from common.log_decorator import generate_and_send_logs, send_log_to_firehose from tests.utils.values_for_ack_backend_tests import ( ValidValues,