Skip to content

Commit 271c6a5

Browse files
committed
Create basic authoriser class
1 parent 4b1f50d commit 271c6a5

File tree

10 files changed

+103
-51
lines changed

10 files changed

+103
-51
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import StrEnum
2+
3+
4+
class ApiOperationCode(StrEnum):
5+
CREATE = "c"
6+
READ = "r"
7+
UPDATE = "u"
8+
DELETE = "d"
9+
SEARCH = "s"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import json
2+
3+
from authorisation.ApiOperationCode import ApiOperationCode
4+
from clients import redis_client, logger
5+
from constants import SUPPLIER_PERMISSIONS_HASH_KEY
6+
7+
8+
class Authoriser:
9+
def __init__(self):
10+
self._cache_client = redis_client
11+
12+
@staticmethod
13+
def _expand_permissions(permissions: list[str]) -> dict[str, list[ApiOperationCode]]:
14+
"""Parses and expands permissions data into a dictionary mapping vaccination types to a list of permitted
15+
API operations. The raw string from Redis will be in the form VAC.PERMS e.g. COVID19.CRUDS"""
16+
expanded_permissions = {}
17+
18+
for permission in permissions:
19+
vaccine_type, operation_codes_str = permission.split(".", maxsplit=1)
20+
vaccine_type = vaccine_type.lower()
21+
operation_codes = [
22+
operation_code
23+
for operation_code in operation_codes_str.lower()
24+
if operation_code in list(ApiOperationCode)
25+
]
26+
expanded_permissions[vaccine_type] = operation_codes
27+
28+
return expanded_permissions
29+
30+
def _get_supplier_permissions(self, supplier_name: str) -> dict[str, list[ApiOperationCode]]:
31+
raw_permissions_data = self._cache_client.hget(SUPPLIER_PERMISSIONS_HASH_KEY, supplier_name)
32+
permissions_data = json.loads(raw_permissions_data) if raw_permissions_data else []
33+
34+
return self._expand_permissions(permissions_data)
35+
36+
def authorise(
37+
self,
38+
supplier_name: str,
39+
requested_operation: ApiOperationCode,
40+
vaccination_types: set[str]
41+
) -> bool:
42+
supplier_permissions = self._get_supplier_permissions(supplier_name)
43+
44+
logger.info(
45+
f"operation: {requested_operation}, supplier_permissions: {supplier_permissions}, "
46+
f"vaccine_types: {vaccination_types}"
47+
)
48+
return all(
49+
requested_operation in supplier_permissions.get(vaccination_type.lower(), [])
50+
for vaccination_type in vaccination_types
51+
)

backend/src/authorisation/__init__.py

Whitespace-only changes.

backend/src/authorization.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from dataclasses import dataclass
22
from enum import Enum
33
from functools import wraps
4-
from typing import Set
54

65
from models.errors import UnauthorizedError
76

87

9-
AUTHENTICATION_HEADER = "AuthenticationType"
8+
AUTHENTICATION_TYPE_HEADER_NAME = "AuthenticationType"
109

1110

1211
@dataclass
@@ -32,17 +31,17 @@ class Authorization:
3231
UnknownPermission is due to proxy bad configuration, and should result in 500. Any invalid value, either
3332
insufficient permissions or bad string, will result in UnauthorizedError if it comes from user.
3433
"""
35-
34+
3635
def authorize(self, aws_event: dict):
3736
auth_type = self._parse_auth_type(aws_event["headers"])
38-
37+
3938
if auth_type not in {AuthType.APP_RESTRICTED, AuthType.CIS2, AuthType.NHS_LOGIN}:
4039
raise UnauthorizedError()
4140

4241
@staticmethod
43-
def _parse_auth_type(headers) -> AuthType:
42+
def _parse_auth_type(headers: dict) -> AuthType:
4443
try:
45-
auth_type = headers[AUTHENTICATION_HEADER]
44+
auth_type = headers[AUTHENTICATION_TYPE_HEADER_NAME]
4645
return AuthType(auth_type)
4746
except ValueError:
4847
# The value of authentication type comes from apigee regardless of auth type. That's why

backend/src/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ class Urls:
2323

2424

2525
GENERIC_SERVER_ERROR_DIAGNOSTICS_MESSAGE = "Unable to process request. Issue may be transient."
26+
SUPPLIER_PERMISSIONS_HASH_KEY = "supplier_permissions"

backend/src/fhir_controller.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import base64
2-
import boto3
32
import json
43
import os
54
import re
65
import uuid
76
from decimal import Decimal
87
from typing import Optional
9-
import boto3
108
from aws_lambda_typing.events import APIGatewayProxyEventV1
119
from fhir.resources.R4B.immunization import Immunization
1210
from boto3 import client as boto3_client
@@ -25,7 +23,6 @@
2523
ValidationError,
2624
IdentifierDuplicationError,
2725
ParameterException,
28-
InconsistentIdError,
2926
UnauthorizedVaxError,
3027
UnauthorizedVaxOnRecordError,
3128
UnauthorizedSystemError,
@@ -101,7 +98,7 @@ def get_immunization_by_identifier(self, aws_event) -> dict:
10198

10299
try:
103100
if resource := self.fhir_service.get_immunization_by_identifier(
104-
identifiers, imms_vax_type_perms, identifier, element):
101+
identifiers, supplier_system, identifier, element):
105102
return FhirController.create_response(200, resource)
106103
except UnauthorizedVaxError as unauthorized:
107104
return self.create_response(403, unauthorized.to_operation_outcome())

backend/src/fhir_repository.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from urllib import response
21
from responses import logger
32
import simplejson as json
43
import os
@@ -87,24 +86,21 @@ class ImmunizationRepository:
8786
def __init__(self, table: Table):
8887
self.table = table
8988

90-
def get_immunization_by_identifier(
91-
self, identifier_pk: str, imms_vax_type_perms: list[str]
92-
) -> Optional[dict]:
89+
def get_immunization_by_identifier(self, identifier_pk: str) -> tuple[Optional[dict], Optional[str]]:
9390
response = self.table.query(
9491
IndexName="IdentifierGSI", KeyConditionExpression=Key("IdentifierPK").eq(identifier_pk)
9592
)
93+
9694
if "Items" in response and len(response["Items"]) > 0:
9795
item = response["Items"][0]
9896
resp = dict()
9997
vaccine_type = self._vaccine_type(item["PatientSK"])
100-
if not validate_permissions(imms_vax_type_perms,ApiOperationCode.SEARCH, [vaccine_type]):
101-
raise UnauthorizedVaxError()
10298
resource = json.loads(item["Resource"])
10399
resp["id"] = resource.get("id")
104100
resp["version"] = int(response["Items"][0]["Version"])
105-
return resp
101+
return resp, vaccine_type
106102
else:
107-
return None
103+
return None, None
108104

109105
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: str) -> Optional[dict]:
110106
response = self.table.get_item(Key={"PK": _make_immunization_pk(imms_id)})
@@ -401,7 +397,7 @@ def find_immunizations(self, patient_identifier: str, vaccine_types: list):
401397

402398
raw_items = self.get_all_items(condition, is_not_deleted)
403399

404-
if raw_items:
400+
if raw_items:
405401
# Filter the response to contain only the requested vaccine types
406402
items = [x for x in raw_items if x["PatientSK"].split("#")[0] in vaccine_types]
407403

@@ -430,7 +426,7 @@ def get_all_items(self, condition, is_not_deleted):
430426
response = self.table.query(**query_args)
431427
if "Items" not in response:
432428
raise UnhandledResponseError(message="No Items in DynamoDB response", response=response)
433-
429+
434430
items = response.get("Items", [])
435431
all_items.extend(items)
436432

backend/src/fhir_service.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
from pydantic import ValidationError
1616

1717
import parameter_parser
18+
from authorisation.ApiOperationCode import ApiOperationCode
19+
from authorisation.Authoriser import Authoriser
1820
from fhir_repository import ImmunizationRepository
19-
from base_utils.base_utils import obtain_field_value
20-
from models.field_names import FieldNames
21-
from models.errors import InvalidPatientId, CustomValidationError, UnhandledResponseError
21+
from models.errors import InvalidPatientId, CustomValidationError, UnauthorizedVaxError
2222
from models.fhir_immunization import ImmunizationValidator
2323
from models.utils.generic_utils import nhs_number_mod11_check, get_occurrence_datetime, create_diagnostics, form_json, get_contained_patient
24-
from models.constants import Constants
2524
from models.errors import MandatoryError
2625
from timer import timed
2726
from filter import Filter
@@ -54,27 +53,27 @@ def __init__(
5453
imms_repo: ImmunizationRepository,
5554
validator: ImmunizationValidator = ImmunizationValidator(),
5655
):
56+
self.authoriser = Authoriser()
5757
self.immunization_repo = imms_repo
5858
self.validator = validator
5959

6060
def get_immunization_by_identifier(
61-
self, identifier_pk: str, imms_vax_type_perms: list[str], identifier: str, element: str
61+
self, identifier_pk: str, supplier_name: str, identifier: str, element: str
6262
) -> Optional[dict]:
6363
"""
6464
Get an Immunization by its ID. Return None if not found. If the patient doesn't have an NHS number,
6565
return the Immunization without calling PDS or checking S flag.
6666
"""
67-
imms_resp = self.immunization_repo.get_immunization_by_identifier(
68-
identifier_pk, imms_vax_type_perms
69-
)
67+
base_url = f"{get_service_url()}/Immunization"
68+
imms_resp, vaccination_type = self.immunization_repo.get_immunization_by_identifier(identifier_pk)
69+
7070
if not imms_resp:
71-
base_url = f"{get_service_url()}/Immunization"
72-
response = form_json(imms_resp, None, None, base_url)
73-
return response
74-
else:
75-
base_url = f"{get_service_url()}/Immunization"
76-
response = form_json(imms_resp, element, identifier, base_url)
77-
return response
71+
return form_json(imms_resp, None, None, base_url)
72+
73+
if not self.authoriser.authorise(supplier_name, ApiOperationCode.SEARCH, {vaccination_type}):
74+
raise UnauthorizedVaxError
75+
76+
return form_json(imms_resp, element, identifier, base_url)
7877

7978
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: list[str]) -> Optional[dict]:
8079
"""

backend/tests/test_authorization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from authorization import (
33
Authorization,
44
UnknownPermission,
5-
AUTHENTICATION_HEADER,
5+
AUTHENTICATION_TYPE_HEADER_NAME,
66
authorize,
77
AuthType
88
)
@@ -12,7 +12,7 @@
1212
def _make_aws_event(auth_type: AuthType):
1313
return {
1414
"headers": {
15-
AUTHENTICATION_HEADER: auth_type.value
15+
AUTHENTICATION_TYPE_HEADER_NAME: auth_type.value
1616
}
1717
}
1818

@@ -35,7 +35,7 @@ def test_decorator_unauthorized(self):
3535
controller = TestAuthorizeDecorator.StubController()
3636
aws_event = {
3737
"headers": {
38-
AUTHENTICATION_HEADER: "InvalidType"
38+
AUTHENTICATION_TYPE_HEADER_NAME: "InvalidType"
3939
}
4040
}
4141
with self.assertRaises(UnknownPermission):
@@ -59,7 +59,7 @@ def test_valid_auth_types(self):
5959
def test_unknown_authorization(self):
6060
aws_event = {
6161
"headers": {
62-
AUTHENTICATION_HEADER: "unknown auth type"
62+
AUTHENTICATION_TYPE_HEADER_NAME: "unknown auth type"
6363
}
6464
}
6565
with self.assertRaises(UnknownPermission):
@@ -80,7 +80,7 @@ def test_valid_app_restricted_auth(self):
8080
def test_invalid_auth_type_raises_unknown_permission(self):
8181
aws_event = {
8282
"headers": {
83-
AUTHENTICATION_HEADER: "InvalidAuthType"
83+
AUTHENTICATION_TYPE_HEADER_NAME: "InvalidAuthType"
8484
}
8585
}
8686
with self.assertRaises(UnknownPermission):
@@ -101,8 +101,8 @@ def test_valid_cis2_auth(self):
101101
def test_invalid_auth_type_raises_unknown_permission(self):
102102
aws_event = {
103103
"headers": {
104-
AUTHENTICATION_HEADER: "InvalidAuthType"
104+
AUTHENTICATION_TYPE_HEADER_NAME: "InvalidAuthType"
105105
}
106106
}
107107
with self.assertRaises(UnknownPermission):
108-
self.authorization.authorize(aws_event)
108+
self.authorization.authorize(aws_event)

backend/tests/test_fhir_controller_authorization.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
Authorization,
1010
UnknownPermission,
1111
AuthType,
12-
AUTHENTICATION_HEADER,
12+
AUTHENTICATION_TYPE_HEADER_NAME,
1313
)
1414
from fhir_controller import FhirController
1515
from fhir_repository import ImmunizationRepository
@@ -20,7 +20,7 @@
2020

2121

2222
def make_aws_event(auth_type: AuthType, permissions=None) -> dict:
23-
return {"headers": {AUTHENTICATION_HEADER: str(auth_type)}}
23+
return {"headers": {AUTHENTICATION_TYPE_HEADER_NAME: str(auth_type)}}
2424

2525

2626
class TestFhirControllerAuthorization(unittest.TestCase):
@@ -36,10 +36,10 @@ def setUp(self):
3636
self.controller = FhirController(self.authorizer, self.service)
3737
self.logger_info_patcher = patch("logging.Logger.info")
3838
self.mock_logger_info = self.logger_info_patcher.start()
39-
39+
4040
def tearDown(self):
4141
patch.stopall()
42-
42+
4343
def test_get_imms_by_id_authorized(self):
4444
aws_event = {"pathParameters": {"id": "an-id"}}
4545

@@ -111,20 +111,20 @@ def test_update_imms_authorized(self, mock_get_supplier_permissions):
111111
_ = self.controller.update_immunization(aws_event)
112112

113113
self.authorizer.authorize.assert_called_once_with(aws_event)
114-
114+
115115
@patch("fhir_controller.get_supplier_permissions")
116116
def test_update_imms_unauthorized_vaxx_in_record(self,mock_get_supplier_permissions):
117117
mock_get_supplier_permissions.return_value = ["Covid19.CRUDS"]
118118
imms_id = str(uuid.uuid4())
119119
aws_event = {"headers": {"E-Tag":1, "SupplierSystem" : "Test"},"pathParameters": {"id": imms_id}, "body": create_covid_19_immunization(imms_id).json()}
120120
self.service.get_immunization_by_id_all.return_value = {"resource":"new_value","Version":1,"DeletedAt": False, "VaccineType":"Flu"}
121-
121+
122122
response = self.controller.update_immunization(aws_event)
123123
self.assertEqual(response["statusCode"], 403)
124124
body = json.loads(response["body"])
125125
self.assertEqual(body["resourceType"], "OperationOutcome")
126-
self.assertEqual(body["issue"][0]["code"], "forbidden")
127-
126+
self.assertEqual(body["issue"][0]["code"], "forbidden")
127+
128128
self.authorizer.authorize.assert_called_once_with(aws_event)
129129

130130
def test_update_imms_unauthorized(self):
@@ -209,4 +209,4 @@ def test_search_imms_unknown_permission(self):
209209
self.assertEqual(response["statusCode"], 500)
210210
body = json.loads(response["body"])
211211
self.assertEqual(body["resourceType"], "OperationOutcome")
212-
self.assertEqual(body["issue"][0]["code"], "exception")
212+
self.assertEqual(body["issue"][0]["code"], "exception")

0 commit comments

Comments
 (0)