diff --git a/lambdas/enums/environment.py b/lambdas/enums/environment.py new file mode 100644 index 000000000..ada6bb860 --- /dev/null +++ b/lambdas/enums/environment.py @@ -0,0 +1,17 @@ +import os +from enum import Enum + + +class Environment(str, Enum): + PROD = "prod" + PRE_PROD = "pre-prod" + NDR_TEST = "ndr-test" + NDR_DEV = "ndr-dev" + + @classmethod + def from_env(cls) -> "Environment": + value = os.getenv("WORKSPACE") + if not value: + return cls.NDR_DEV + + return cls._value2member_map_.get(value.lower(), cls.NDR_DEV) diff --git a/lambdas/enums/mtls.py b/lambdas/enums/mtls.py index e26d9bc7f..37c3ba1a5 100644 --- a/lambdas/enums/mtls.py +++ b/lambdas/enums/mtls.py @@ -1,6 +1,11 @@ +import boto3 +import json + from enum import StrEnum, auto +from functools import lru_cache from enums.lambda_error import LambdaError +from enums.environment import Environment from utils.audit_logging_setup import LoggingService from utils.lambda_exceptions import InvalidDocTypeException @@ -12,13 +17,8 @@ class MtlsCommonNames(StrEnum): @classmethod def allowed_names(cls) -> dict["MtlsCommonNames", list[str]]: - return { - cls.PDM: [ - "ndrclient.main.int.pdm.national.nhs.uk", - "client.dev.ndr.national.nhs.uk", - "client.preprod.ndr.national.nhs.uk" - ] - } + raw = cls._get_mtls_common_names() + return {cls[k]: v for k, v in raw.items() if k in cls.__members__} @classmethod def from_common_name(cls, common_name: str) -> "MtlsCommonNames | None": @@ -27,3 +27,14 @@ def from_common_name(cls, common_name: str) -> "MtlsCommonNames | None": return doc_type logger.error(f"mTLS common name {common_name} - is not supported") raise InvalidDocTypeException(400, LambdaError.DocTypeInvalid) + + @classmethod + @lru_cache(maxsize=1) + def _get_mtls_common_names(cls) -> dict[str, list[str]]: + ssm = boto3.client("ssm") + environment = Environment.from_env().value + response = ssm.get_parameter( + Name=f"/ndr/{environment}/mtls_common_names", + WithDecryption=True, + ) + return json.loads(response["Parameter"]["Value"]) diff --git a/lambdas/tests/unit/enums/test_environment.py b/lambdas/tests/unit/enums/test_environment.py new file mode 100644 index 000000000..467c646df --- /dev/null +++ b/lambdas/tests/unit/enums/test_environment.py @@ -0,0 +1,46 @@ +import pytest + +from enums.environment import Environment + + +@pytest.mark.parametrize( + "env_value, expected", + [ + ("prod", Environment.PROD), + ("pre-prod", Environment.PRE_PROD), + ("ndr-test", Environment.NDR_TEST), + ("ndr-dev", Environment.NDR_DEV), + ], +) +def test_valid_workspace_values(monkeypatch, env_value, expected): + monkeypatch.setenv("WORKSPACE", env_value) + assert Environment.from_env() == expected + + +@pytest.mark.parametrize( + "env_value", + [ + "abcd1", + "ndr000", + "prmp000", + "foobar", + ], +) +def test_invalid_workspace_defaults_to_ndr_dev(monkeypatch, env_value): + monkeypatch.setenv("WORKSPACE", env_value) + assert Environment.from_env() == Environment.NDR_DEV + + +def test_workspace_is_case_insensitive(monkeypatch): + monkeypatch.setenv("WORKSPACE", "PRE-PROD") + assert Environment.from_env() == Environment.PRE_PROD + + +def test_workspace_not_set_defaults_to_ndr_dev(monkeypatch): + monkeypatch.delenv("WORKSPACE", raising=False) + assert Environment.from_env() == Environment.NDR_DEV + + +def test_workspace_empty_string_defaults_to_ndr_dev(monkeypatch): + monkeypatch.setenv("WORKSPACE", "") + assert Environment.from_env() == Environment.NDR_DEV diff --git a/lambdas/tests/unit/enums/test_mtls.py b/lambdas/tests/unit/enums/test_mtls.py index cc0c56566..45dd73b1a 100644 --- a/lambdas/tests/unit/enums/test_mtls.py +++ b/lambdas/tests/unit/enums/test_mtls.py @@ -7,11 +7,17 @@ @pytest.mark.parametrize( ["common_name", "expected"], [ - ("ndrclient.main.int.pdm.national.nhs.uk", MtlsCommonNames.PDM), - ("client.dev.ndr.national.nhs.uk", MtlsCommonNames.PDM), + ("xxx", MtlsCommonNames.PDM), + ("yyy", MtlsCommonNames.PDM), + ("zzz", MtlsCommonNames.PDM), ], ) -def test_mtls_enum_returned(common_name, expected): +def test_mtls_enum_returned(common_name, expected, monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod(lambda cls: {"PDM": ["xxx", "yyy", "zzz"]}), + ) doc_type_enum = MtlsCommonNames.from_common_name(common_name) assert doc_type_enum == expected @@ -24,7 +30,12 @@ def test_mtls_enum_returned(common_name, expected): "foo.bar", ], ) -def test_mtls_enum_error_raised(common_name): +def test_mtls_enum_error_raised(common_name, monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod(lambda cls: {"PDM": ["xxx", "yyy", "zzz"]}), + ) with pytest.raises(InvalidDocTypeException) as excinfo: MtlsCommonNames.from_common_name(common_name) assert excinfo.value.status_code == 400 diff --git a/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py b/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py index 01c2bb148..ad1521349 100644 --- a/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py +++ b/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py @@ -1,4 +1,5 @@ import pytest +from enums.mtls import MtlsCommonNames from enums.snomed_codes import SnomedCodes from handlers.get_fhir_document_reference_handler import ( extract_document_parameters, @@ -63,8 +64,18 @@ def mock_document_service(mocker): return mock_service_instance +@pytest.fixture +def mock_mtls_common_names(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod(lambda cls: {"PDM": ["ndrclient.main.int.pdm.national.nhs.uk"]}), + ) + + def test_lambda_handler_happy_path_with_mtls_pdm_login( set_env, + mock_mtls_common_names, mock_document_service, context, ): @@ -85,7 +96,7 @@ def test_lambda_handler_happy_path_with_mtls_pdm_login( ) -def test_extract_bearer_token_when_pdm(context): +def test_extract_bearer_token_when_pdm(context, mock_mtls_common_names): token = extract_bearer_token(MOCK_MTLS_VALID_EVENT, context) assert token is None diff --git a/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_search_service.py b/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_search_service.py index 8d78c40e5..3ec59484f 100644 --- a/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_search_service.py +++ b/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_search_service.py @@ -1,4 +1,5 @@ import pytest +from enums.mtls import MtlsCommonNames from enums.snomed_codes import SnomedCodes from freezegun import freeze_time from models.document_reference import DocumentReference @@ -44,6 +45,15 @@ def mock_filter_builder(mocker): return mock_filter +@pytest.fixture +def mock_mtls_common_names(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod(lambda cls: {"PDM": ["ndrclient.main.int.pdm.national.nhs.uk"]}), + ) + + @pytest.mark.parametrize( "common_name, expected", [ @@ -72,7 +82,9 @@ def mock_filter_builder(mocker): ({}, ["test_pdm_dynamoDB_table", "test_lg_dynamoDB_table"]), ], ) -def test_get_pdm_table(set_env, mock_document_service, common_name, expected): +def test_get_pdm_table( + set_env, mock_document_service, common_name, expected, mock_mtls_common_names +): cn = validate_common_name_in_mtls(common_name) tables = mock_document_service._get_table_names(cn) assert tables == expected diff --git a/lambdas/tests/unit/services/test_pdm_post_fhir_document_reference_service.py b/lambdas/tests/unit/services/test_pdm_post_fhir_document_reference_service.py index 101a1bbd5..ca143721c 100644 --- a/lambdas/tests/unit/services/test_pdm_post_fhir_document_reference_service.py +++ b/lambdas/tests/unit/services/test_pdm_post_fhir_document_reference_service.py @@ -60,6 +60,22 @@ def mock_fhir_doc_ref_base_service(mocker, setup_request_context): yield service +@pytest.fixture +def mock_mtls_common_names(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod( + lambda cls: { + "PDM": [ + "ndrclient.main.int.pdm.national.nhs.uk", + "client.dev.ndr.national.nhs.uk", + ] + } + ), + ) + + @pytest.fixture def setup_request_context(): request_context.authorization = { @@ -322,6 +338,7 @@ def test_get_dynamo_table_for_lloyd_george_doc_type( def test_process_mtls_fhir_document_reference_with_binary( mock_fhir_doc_ref_base_service, mock_post_fhir_doc_ref_service, + mock_mtls_common_names, valid_mtls_fhir_doc_with_binary, valid_mtls_request_context, ): diff --git a/lambdas/tests/unit/utils/test_lambda_handler_utils.py b/lambdas/tests/unit/utils/test_lambda_handler_utils.py index 0117c5260..d7f6b7ce2 100644 --- a/lambdas/tests/unit/utils/test_lambda_handler_utils.py +++ b/lambdas/tests/unit/utils/test_lambda_handler_utils.py @@ -1,5 +1,6 @@ import pytest from enums.lambda_error import LambdaError +from enums.mtls import MtlsCommonNames from enums.snomed_codes import SnomedCodes from tests.unit.conftest import TEST_UUID from utils.lambda_exceptions import ( @@ -74,6 +75,22 @@ ] +@pytest.fixture +def mock_mtls_common_names(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod( + lambda cls: { + "PDM": [ + "ndrclient.main.int.pdm.national.nhs.uk", + "client.dev.ndr.national.nhs.uk", + ] + } + ), + ) + + @pytest.mark.parametrize( "function_name, mock_event", [ @@ -88,7 +105,7 @@ def test_extract_bearer_token_happy_paths(context, function_name, mock_event): assert token == f"Bearer {TEST_UUID}" -def test_extract_bearer_token_when_pdm(context): +def test_extract_bearer_token_when_pdm(context, mock_mtls_common_names): token = extract_bearer_token(MOCK_MTLS_VALID_EVENT, context) assert token is None diff --git a/lambdas/tests/unit/utils/test_lambda_header_utils.py b/lambdas/tests/unit/utils/test_lambda_header_utils.py index 48ad31f8f..5b675701c 100644 --- a/lambdas/tests/unit/utils/test_lambda_header_utils.py +++ b/lambdas/tests/unit/utils/test_lambda_header_utils.py @@ -64,14 +64,32 @@ def invalid_mtls_request_context(): } -def test_validate_valid_common_name(valid_mtls_request_context): +@pytest.fixture +def mock_mtls_common_names(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod( + lambda cls: { + "PDM": [ + "ndrclient.main.int.pdm.national.nhs.uk", + "client.dev.ndr.national.nhs.uk", + ] + } + ), + ) + + +def test_validate_valid_common_name(valid_mtls_request_context, mock_mtls_common_names): """Test validate_common_name when mtls and pdm.""" result = validate_common_name_in_mtls(valid_mtls_request_context) assert result == MtlsCommonNames.PDM.value -def test_validate_invalid_common_name(invalid_mtls_request_context): +def test_validate_invalid_common_name( + invalid_mtls_request_context, mock_mtls_common_names +): """Test validate_common_name when mtls but not allowed.""" with pytest.raises(InvalidDocTypeException) as excinfo: validate_common_name_in_mtls(invalid_mtls_request_context)