Skip to content

Commit c239cae

Browse files
committed
Refactored a few more endpoints to use consolidated auth
1 parent f06593a commit c239cae

File tree

8 files changed

+121
-144
lines changed

8 files changed

+121
-144
lines changed
File renamed without changes.
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from authorisation.ApiOperationCode import ApiOperationCode
3+
from authorisation.api_operation_code import ApiOperationCode
44
from clients import redis_client, logger
55
from constants import SUPPLIER_PERMISSIONS_HASH_KEY
66

@@ -27,19 +27,19 @@ def _expand_permissions(permissions: list[str]) -> dict[str, list[ApiOperationCo
2727

2828
return expanded_permissions
2929

30-
def _get_supplier_permissions(self, supplier_name: str) -> dict[str, list[ApiOperationCode]]:
31-
raw_permissions_data = self._cache_client.hget(SUPPLIER_PERMISSIONS_HASH_KEY, supplier_name)
30+
def _get_supplier_permissions(self, supplier_system: str) -> dict[str, list[ApiOperationCode]]:
31+
raw_permissions_data = self._cache_client.hget(SUPPLIER_PERMISSIONS_HASH_KEY, supplier_system)
3232
permissions_data = json.loads(raw_permissions_data) if raw_permissions_data else []
3333

3434
return self._expand_permissions(permissions_data)
3535

3636
def authorise(
3737
self,
38-
supplier_name: str,
38+
supplier_system: str,
3939
requested_operation: ApiOperationCode,
4040
vaccination_types: set[str]
4141
) -> bool:
42-
supplier_permissions = self._get_supplier_permissions(supplier_name)
42+
supplier_permissions = self._get_supplier_permissions(supplier_system)
4343

4444
logger.info(
4545
f"operation: {requested_operation}, supplier_permissions: {supplier_permissions}, "

backend/src/fhir_controller.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_immunization_by_id(self, aws_event) -> dict:
125125
return self.create_response(403, unauthorized.to_operation_outcome())
126126

127127
try:
128-
if resource := self.fhir_service.get_immunization_by_id(imms_id, imms_vax_type_perms):
128+
if resource := self.fhir_service.get_immunization_by_id(imms_id, supplier_system):
129129
version = str()
130130
if isinstance(resource, Immunization):
131131
resp = resource
@@ -157,19 +157,19 @@ def create_immunization(self, aws_event):
157157
return self.create_response(403, unauthorized.to_operation_outcome())
158158

159159
# Call the common method and unpack the results
160+
# TODO - can remove this and the block above. Only need supplier system
160161
response, imms_vax_type_perms, supplier_system = self.check_vaccine_type_permissions(
161162
aws_event
162163
)
163164
if response:
164165
return response
165166

166167
try:
167-
imms = json.loads(aws_event["body"], parse_float=Decimal)
168+
immunisation = json.loads(aws_event["body"], parse_float=Decimal)
168169
except json.decoder.JSONDecodeError as e:
169170
return self._create_bad_request(f"Request's body contains malformed JSON: {e}")
170171
try:
171-
resource = self.fhir_service.create_immunization(
172-
imms, imms_vax_type_perms, supplier_system)
172+
resource = self.fhir_service.create_immunization(immunisation, supplier_system)
173173
if "diagnostics" in resource:
174174
exp_error = create_operation_outcome(
175175
resource_id=str(uuid.uuid4()),
@@ -366,10 +366,9 @@ def delete_immunization(self, aws_event):
366366
except UnauthorizedError as unauthorized:
367367
return self.create_response(403, unauthorized.to_operation_outcome())
368368

369-
# Validate the imms id - start
369+
# Validate the imms id
370370
if id_error := self._validate_id(imms_id):
371371
return FhirController.create_response(400, json.dumps(id_error))
372-
# Validate the imms id - end
373372

374373
# Call the common method and unpack the results
375374
response, imms_vax_type_perms, supplier_system = self.check_vaccine_type_permissions(

backend/src/fhir_repository.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def get_immunization_by_identifier(self, identifier_pk: str) -> tuple[Optional[d
102102
else:
103103
return None, None
104104

105-
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: str) -> Optional[dict]:
105+
def get_immunization_by_id(self, imms_id: str) -> Optional[dict]:
106106
response = self.table.get_item(Key={"PK": _make_immunization_pk(imms_id)})
107107
item = response.get("Item")
108108

@@ -111,12 +111,6 @@ def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: str) -> Opti
111111
if item.get("DeletedAt") and item["DeletedAt"] != "reinstated":
112112
return None
113113

114-
# Get vaccine type + validate permissions
115-
vaccine_type = self._vaccine_type(item["PatientSK"])
116-
if not validate_permissions(imms_vax_type_perms, ApiOperationCode.READ, [vaccine_type]):
117-
raise UnauthorizedVaxError()
118-
119-
# Build response
120114
return {
121115
"Resource": json.loads(item["Resource"]),
122116
"Version": item["Version"]
@@ -155,14 +149,11 @@ def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]
155149
else:
156150
return None
157151

158-
def create_immunization(
159-
self, immunization: dict, patient: any, imms_vax_type_perms, supplier_system
160-
) -> dict:
152+
def create_immunization(self, immunization: dict, patient: any, supplier_system: str) -> dict:
161153
new_id = str(uuid.uuid4())
162154
immunization["id"] = new_id
163155
attr = RecordAttributes(immunization, patient)
164-
if not validate_permissions(imms_vax_type_perms,ApiOperationCode.CREATE, [attr.vaccine_type]):
165-
raise UnauthorizedVaxError()
156+
166157
query_response = _query_identifier(self.table, "IdentifierGSI", "IdentifierPK", attr.identifier)
167158

168159
if query_response is not None:

backend/src/fhir_service.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
BundleEntrySearch,
1313
)
1414
from fhir.resources.R4B.immunization import Immunization
15-
from poetry.console.commands import self
1615
from pydantic import ValidationError
1716

1817
import parameter_parser
19-
from authorisation.ApiOperationCode import ApiOperationCode
20-
from authorisation.Authoriser import Authoriser
18+
from authorisation.api_operation_code import ApiOperationCode
19+
from authorisation.authoriser import Authoriser
2120
from fhir_repository import ImmunizationRepository
2221
from models.errors import InvalidPatientId, CustomValidationError, UnauthorizedVaxError
2322
from models.fhir_immunization import ImmunizationValidator
2423
from models.utils.generic_utils import nhs_number_mod11_check, get_occurrence_datetime, create_diagnostics, form_json, get_contained_patient
2524
from models.errors import MandatoryError
25+
from models.utils.validation_utils import get_vaccine_type
2626
from timer import timed
2727
from filter import Filter
2828

@@ -77,16 +77,20 @@ def get_immunization_by_identifier(
7777

7878
return form_json(imms_resp, element, identifier, base_url)
7979

80-
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: list[str]) -> Optional[dict]:
80+
def get_immunization_by_id(self, imms_id: str, supplier_system: str) -> Optional[dict]:
8181
"""
8282
Get an Immunization by its ID. Return None if it is not found. If the patient doesn't have an NHS number,
8383
return the Immunization without calling PDS or checking S flag.
8484
"""
85-
if not (imms_resp := self.immunization_repo.get_immunization_by_id(imms_id, imms_vax_type_perms)):
85+
if not (imms_resp := self.immunization_repo.get_immunization_by_id(imms_id)):
8686
return None
8787

8888
# Returns the Immunisation full resource with no obfuscation
8989
resource = imms_resp.get("Resource", {})
90+
vaccination_type = get_vaccine_type(resource)
91+
92+
if not self.authoriser.authorise(supplier_system, ApiOperationCode.READ, {vaccination_type}):
93+
raise UnauthorizedVaxError()
9094

9195
return {
9296
"Version": imms_resp.get("Version", ""),
@@ -106,10 +110,7 @@ def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]
106110
imms_resp = self.immunization_repo.get_immunization_by_id_all(imms_id, imms)
107111
return imms_resp
108112

109-
def create_immunization(
110-
self, immunization: dict, imms_vax_type_perms, supplier_system
111-
) -> Immunization:
112-
113+
def create_immunization(self, immunization: dict, supplier_system: str) -> dict | Immunization:
113114
if immunization.get("id") is not None:
114115
raise CustomValidationError("id field must not be present for CREATE operation")
115116

@@ -118,14 +119,17 @@ def create_immunization(
118119
except (ValidationError, ValueError, MandatoryError) as error:
119120
raise CustomValidationError(message=str(error)) from error
120121
patient = self._validate_patient(immunization)
122+
121123
if "diagnostics" in patient:
122124
return patient
123125

124-
imms = self.immunization_repo.create_immunization(
125-
immunization, patient, imms_vax_type_perms, supplier_system
126-
)
126+
vaccination_type = get_vaccine_type(immunization)
127127

128-
return Immunization.parse_obj(imms)
128+
if not self.authoriser.authorise(supplier_system, ApiOperationCode.CREATE, {vaccination_type}):
129+
raise UnauthorizedVaxError()
130+
131+
immunisation = self.immunization_repo.create_immunization(immunization, patient, supplier_system)
132+
return Immunization.parse_obj(immunisation)
129133

130134
def update_immunization(
131135
self,
@@ -197,7 +201,7 @@ def update_reinstated_immunization(
197201

198202
return UpdateOutcome.UPDATE, Immunization.parse_obj(imms), updated_version
199203

200-
def delete_immunization(self, imms_id, imms_vax_type_perms, supplier_system) -> Immunization:
204+
def delete_immunization(self, imms_id: str, imms_vax_type_perms, supplier_system: str) -> Immunization:
201205
"""
202206
Delete an Immunization if it exits and return the ID back if successful.
203207
Exception will be raised if resource didn't exit. Multiple calls to this method won't change

backend/tests/test_fhir_controller.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
import unittest
66
import uuid
77

8-
from unittest.mock import patch
98
from fhir.resources.R4B.bundle import Bundle
109
from fhir.resources.R4B.immunization import Immunization
1110
from unittest.mock import create_autospec, ANY, patch, Mock
1211
from urllib.parse import urlencode
1312
import urllib.parse
14-
from moto import mock_aws
1513
from authorization import Authorization
1614
from fhir_controller import FhirController
1715
from fhir_repository import ImmunizationRepository
@@ -22,15 +20,12 @@
2220
InvalidPatientId,
2321
CustomValidationError,
2422
ParameterException,
25-
InconsistentIdError,
2623
UnauthorizedVaxError,
27-
UnauthorizedError,
2824
IdentifierDuplicationError,
2925
)
3026
from tests.utils.immunization_utils import create_covid_19_immunization
3127
from parameter_parser import patient_identifier_system, process_search_params
3228
from tests.utils.generic_utils import load_json_data
33-
from tests.utils.values_for_tests import ValidValues
3429

3530
class TestFhirControllerBase(unittest.TestCase):
3631
"""Base class for all tests to set up common fixtures"""
@@ -211,7 +206,7 @@ def test_get_imms_by_identifer_patient_identifier_and_element_present(self, mock
211206
# When
212207
response = self.controller.get_immunization_by_identifier(lambda_event)
213208
# Then
214-
self.service.get_immunization_by_identifier.assert_not_called
209+
self.service.get_immunization_by_identifier.assert_not_called()
215210

216211
self.assertEqual(response["statusCode"], 400)
217212
body = json.loads(response["body"])
@@ -234,7 +229,7 @@ def test_get_imms_by_identifer_both_body_and_query_params_present(self, mock_get
234229
# When
235230
response = self.controller.get_immunization_by_identifier(lambda_event)
236231
# Then
237-
self.service.get_immunization_by_identifier.assert_not_called
232+
self.service.get_immunization_by_identifier.assert_not_called()
238233

239234
self.assertEqual(response["statusCode"], 400)
240235
body = json.loads(response["body"])
@@ -257,7 +252,7 @@ def test_get_imms_by_identifer_imms_identifier_and_element_not_present(self, moc
257252
# When
258253
response = self.controller.get_immunization_by_identifier(lambda_event)
259254
# Then
260-
self.service.get_immunization_by_identifier.assert_not_called
255+
self.service.get_immunization_by_identifier.assert_not_called()
261256

262257
self.assertEqual(response["statusCode"], 400)
263258
body = json.loads(response["body"])
@@ -281,7 +276,7 @@ def test_get_imms_by_identifer_both_identifier_present(self, mock_get_supplier_p
281276
# When
282277
response = self.controller.get_immunization_by_identifier(lambda_event)
283278
# Then
284-
self.service.get_immunization_by_identifier.assert_not_called
279+
self.service.get_immunization_by_identifier.assert_not_called()
285280

286281
self.assertEqual(response["statusCode"], 400)
287282
body = json.loads(response["body"])
@@ -565,7 +560,7 @@ def test_get_imms_by_identifer_patient_identifier_and_element_present(self, mock
565560
# When
566561
response = self.controller.get_immunization_by_identifier(lambda_event)
567562
# Then
568-
self.service.get_immunization_by_identifier.assert_not_called
563+
self.service.get_immunization_by_identifier.assert_not_called()
569564

570565
self.assertEqual(response["statusCode"], 400)
571566
body = json.loads(response["body"])
@@ -585,7 +580,7 @@ def test_get_imms_by_identifer_imms_identifier_and_element_not_present(self,mock
585580
# When
586581
response = self.controller.get_immunization_by_identifier(lambda_event)
587582
# Then
588-
self.service.get_immunization_by_identifier.assert_not_called
583+
self.service.get_immunization_by_identifier.assert_not_called()
589584

590585
self.assertEqual(response["statusCode"], 400)
591586
body = json.loads(response["body"])
@@ -669,7 +664,7 @@ def test_get_imms_by_identifer_both_identifier_present(self, mock_get_supplier_p
669664
# When
670665
response = self.controller.get_immunization_by_identifier(lambda_event)
671666
# Then
672-
self.service.get_immunization_by_identifier.assert_not_called
667+
self.service.get_immunization_by_identifier.assert_not_called()
673668

674669
self.assertEqual(response["statusCode"], 400)
675670
body = json.loads(response["body"])
@@ -810,7 +805,7 @@ def test_get_imms_by_id(self, mock_permissions):
810805
response = self.controller.get_immunization_by_id(lambda_event)
811806
# Then
812807
mock_permissions.assert_called_once_with("test")
813-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
808+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test")
814809

815810
self.assertEqual(response["statusCode"], 200)
816811
body = json.loads(response["body"])
@@ -833,6 +828,7 @@ def test_get_imms_by_id_unauthorised_vax_error(self,mock_permissions):
833828
# Then
834829
mock_permissions.assert_called_once_with("test")
835830
self.assertEqual(response["statusCode"], 403)
831+
836832
@patch("fhir_controller.get_supplier_permissions")
837833
def test_get_imms_by_id_no_vax_permission(self, mock_permissions):
838834
"""it should return Immunization Id if it exists"""
@@ -867,7 +863,7 @@ def test_not_found(self,mock_permissions):
867863

868864
# Then
869865
mock_permissions.assert_called_once_with("test")
870-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
866+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test")
871867

872868
self.assertEqual(response["statusCode"], 404)
873869
body = json.loads(response["body"])
@@ -913,7 +909,7 @@ def test_create_immunization(self,mock_get_permissions):
913909

914910
imms_obj = json.loads(aws_event["body"])
915911
mock_get_permissions.assert_called_once_with("Test")
916-
self.service.create_immunization.assert_called_once_with(imms_obj, ["COVID19.CRUDS", "FLU.CRUDS"], "Test")
912+
self.service.create_immunization.assert_called_once_with(imms_obj, "Test")
917913
self.assertEqual(response["statusCode"], 201)
918914
self.assertTrue("body" not in response)
919915
self.assertTrue(response["headers"]["Location"].endswith(f"Immunization/{imms_id}"))
@@ -1324,7 +1320,7 @@ def test_update_record_exists(self, mock_get_supplier_permissions):
13241320
response = self.controller.get_immunization_by_id(lambda_event)
13251321

13261322
# Then
1327-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
1323+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "Test")
13281324

13291325
self.assertEqual(response["statusCode"], 404)
13301326
body = json.loads(response["body"])

0 commit comments

Comments
 (0)