Skip to content

Commit c4ac34f

Browse files
committed
Refactored a few more endpoints to use consolidated auth
1 parent 947349b commit c4ac34f

File tree

8 files changed

+142
-142
lines changed

8 files changed

+142
-142
lines changed
File renamed without changes.
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22

3-
from authorisation.ApiOperationCode import ApiOperationCode
3+
from authorisation.api_operation_code import ApiOperationCode
44
from clients import redis_client, logger
55
from constants import SUPPLIER_PERMISSIONS_HASH_KEY
66

@@ -27,19 +27,19 @@ def _expand_permissions(permissions: list[str]) -> dict[str, list[ApiOperationCo
2727

2828
return expanded_permissions
2929

30-
def _get_supplier_permissions(self, supplier_name: str) -> dict[str, list[ApiOperationCode]]:
31-
raw_permissions_data = self._cache_client.hget(SUPPLIER_PERMISSIONS_HASH_KEY, supplier_name)
30+
def _get_supplier_permissions(self, supplier_system: str) -> dict[str, list[ApiOperationCode]]:
31+
raw_permissions_data = self._cache_client.hget(SUPPLIER_PERMISSIONS_HASH_KEY, supplier_system)
3232
permissions_data = json.loads(raw_permissions_data) if raw_permissions_data else []
3333

3434
return self._expand_permissions(permissions_data)
3535

3636
def authorise(
3737
self,
38-
supplier_name: str,
38+
supplier_system: str,
3939
requested_operation: ApiOperationCode,
4040
vaccination_types: set[str]
4141
) -> bool:
42-
supplier_permissions = self._get_supplier_permissions(supplier_name)
42+
supplier_permissions = self._get_supplier_permissions(supplier_system)
4343

4444
logger.info(
4545
f"operation: {requested_operation}, supplier_permissions: {supplier_permissions}, "

backend/src/fhir_controller.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_immunization_by_id(self, aws_event) -> dict:
125125
return self.create_response(403, unauthorized.to_operation_outcome())
126126

127127
try:
128-
if resource := self.fhir_service.get_immunization_by_id(imms_id, imms_vax_type_perms):
128+
if resource := self.fhir_service.get_immunization_by_id(imms_id, supplier_system):
129129
version = str()
130130
if isinstance(resource, Immunization):
131131
resp = resource
@@ -157,19 +157,19 @@ def create_immunization(self, aws_event):
157157
return self.create_response(403, unauthorized.to_operation_outcome())
158158

159159
# Call the common method and unpack the results
160+
# TODO - can remove this and the block above. Only need supplier system
160161
response, imms_vax_type_perms, supplier_system = self.check_vaccine_type_permissions(
161162
aws_event
162163
)
163164
if response:
164165
return response
165166

166167
try:
167-
imms = json.loads(aws_event["body"], parse_float=Decimal)
168+
immunisation = json.loads(aws_event["body"], parse_float=Decimal)
168169
except json.decoder.JSONDecodeError as e:
169170
return self._create_bad_request(f"Request's body contains malformed JSON: {e}")
170171
try:
171-
resource = self.fhir_service.create_immunization(
172-
imms, imms_vax_type_perms, supplier_system)
172+
resource = self.fhir_service.create_immunization(immunisation, supplier_system)
173173
if "diagnostics" in resource:
174174
exp_error = create_operation_outcome(
175175
resource_id=str(uuid.uuid4()),
@@ -366,10 +366,9 @@ def delete_immunization(self, aws_event):
366366
except UnauthorizedError as unauthorized:
367367
return self.create_response(403, unauthorized.to_operation_outcome())
368368

369-
# Validate the imms id - start
369+
# Validate the imms id
370370
if id_error := self._validate_id(imms_id):
371371
return FhirController.create_response(400, json.dumps(id_error))
372-
# Validate the imms id - end
373372

374373
# Call the common method and unpack the results
375374
response, imms_vax_type_perms, supplier_system = self.check_vaccine_type_permissions(

backend/src/fhir_repository.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def get_immunization_by_identifier(self, identifier_pk: str) -> tuple[Optional[d
104104
else:
105105
return None, None
106106

107-
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: str) -> Optional[dict]:
107+
def get_immunization_by_id(self, imms_id: str) -> Optional[dict]:
108108
response = self.table.get_item(Key={"PK": _make_immunization_pk(imms_id)})
109109
item = response.get("Item")
110110

@@ -113,12 +113,6 @@ def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: str) -> Opti
113113
if item.get("DeletedAt") and item["DeletedAt"] != "reinstated":
114114
return None
115115

116-
# Get vaccine type + validate permissions
117-
vaccine_type = self._vaccine_type(item["PatientSK"])
118-
if not validate_permissions(imms_vax_type_perms, ApiOperationCode.READ, [vaccine_type]):
119-
raise UnauthorizedVaxError()
120-
121-
# Build response
122116
return {
123117
"Resource": json.loads(item["Resource"]),
124118
"Version": item["Version"]
@@ -157,14 +151,11 @@ def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]
157151
else:
158152
return None
159153

160-
def create_immunization(
161-
self, immunization: dict, patient: any, imms_vax_type_perms, supplier_system
162-
) -> dict:
154+
def create_immunization(self, immunization: dict, patient: any, supplier_system: str) -> dict:
163155
new_id = str(uuid.uuid4())
164156
immunization["id"] = new_id
165157
attr = RecordAttributes(immunization, patient)
166-
if not validate_permissions(imms_vax_type_perms,ApiOperationCode.CREATE, [attr.vaccine_type]):
167-
raise UnauthorizedVaxError()
158+
168159
query_response = _query_identifier(self.table, "IdentifierGSI", "IdentifierPK", attr.identifier)
169160

170161
if query_response is not None:

backend/src/fhir_service.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
BundleEntrySearch,
1313
)
1414
from fhir.resources.R4B.immunization import Immunization
15-
from poetry.console.commands import self
1615
from pydantic import ValidationError
1716

1817
import parameter_parser
19-
from authorisation.ApiOperationCode import ApiOperationCode
20-
from authorisation.Authoriser import Authoriser
18+
from authorisation.api_operation_code import ApiOperationCode
19+
from authorisation.authoriser import Authoriser
2120
from fhir_repository import ImmunizationRepository
2221
from models.errors import InvalidPatientId, CustomValidationError, UnauthorizedVaxError
2322
from models.fhir_immunization import ImmunizationValidator
2423
from models.utils.generic_utils import nhs_number_mod11_check, get_occurrence_datetime, create_diagnostics, form_json, get_contained_patient
2524
from models.errors import MandatoryError
25+
from models.utils.validation_utils import get_vaccine_type
2626
from timer import timed
2727
from filter import Filter
2828

@@ -80,16 +80,20 @@ def get_immunization_by_identifier(
8080
imms_resp['resource'] = filtered_resource
8181
return form_json(imms_resp, element, identifier, base_url)
8282

83-
def get_immunization_by_id(self, imms_id: str, imms_vax_type_perms: list[str]) -> Optional[dict]:
83+
def get_immunization_by_id(self, imms_id: str, supplier_system: str) -> Optional[dict]:
8484
"""
8585
Get an Immunization by its ID. Return None if it is not found. If the patient doesn't have an NHS number,
8686
return the Immunization.
8787
"""
88-
if not (imms_resp := self.immunization_repo.get_immunization_by_id(imms_id, imms_vax_type_perms)):
88+
if not (imms_resp := self.immunization_repo.get_immunization_by_id(imms_id)):
8989
return None
9090

9191
# Returns the Immunisation full resource with no obfuscation
9292
resource = imms_resp.get("Resource", {})
93+
vaccination_type = get_vaccine_type(resource)
94+
95+
if not self.authoriser.authorise(supplier_system, ApiOperationCode.READ, {vaccination_type}):
96+
raise UnauthorizedVaxError()
9397

9498
return {
9599
"Version": imms_resp.get("Version", ""),
@@ -109,10 +113,7 @@ def get_immunization_by_id_all(self, imms_id: str, imms: dict) -> Optional[dict]
109113
imms_resp = self.immunization_repo.get_immunization_by_id_all(imms_id, imms)
110114
return imms_resp
111115

112-
def create_immunization(
113-
self, immunization: dict, imms_vax_type_perms, supplier_system
114-
) -> Immunization:
115-
116+
def create_immunization(self, immunization: dict, supplier_system: str) -> dict | Immunization:
116117
if immunization.get("id") is not None:
117118
raise CustomValidationError("id field must not be present for CREATE operation")
118119

@@ -121,14 +122,17 @@ def create_immunization(
121122
except (ValidationError, ValueError, MandatoryError) as error:
122123
raise CustomValidationError(message=str(error)) from error
123124
patient = self._validate_patient(immunization)
125+
124126
if "diagnostics" in patient:
125127
return patient
126128

127-
imms = self.immunization_repo.create_immunization(
128-
immunization, patient, imms_vax_type_perms, supplier_system
129-
)
129+
vaccination_type = get_vaccine_type(immunization)
130130

131-
return Immunization.parse_obj(imms)
131+
if not self.authoriser.authorise(supplier_system, ApiOperationCode.CREATE, {vaccination_type}):
132+
raise UnauthorizedVaxError()
133+
134+
immunisation = self.immunization_repo.create_immunization(immunization, patient, supplier_system)
135+
return Immunization.parse_obj(immunisation)
132136

133137
def update_immunization(
134138
self,
@@ -200,7 +204,7 @@ def update_reinstated_immunization(
200204

201205
return UpdateOutcome.UPDATE, Immunization.parse_obj(imms), updated_version
202206

203-
def delete_immunization(self, imms_id, imms_vax_type_perms, supplier_system) -> Immunization:
207+
def delete_immunization(self, imms_id: str, imms_vax_type_perms, supplier_system: str) -> Immunization:
204208
"""
205209
Delete an Immunization if it exits and return the ID back if successful.
206210
Exception will be raised if resource didn't exit. Multiple calls to this method won't change

backend/tests/test_fhir_controller.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
import unittest
66
import uuid
77

8-
from unittest.mock import patch
98
from fhir.resources.R4B.bundle import Bundle
109
from fhir.resources.R4B.immunization import Immunization
1110
from unittest.mock import create_autospec, ANY, patch, Mock
1211
from urllib.parse import urlencode
1312
import urllib.parse
14-
from moto import mock_aws
1513
from authorization import Authorization
1614
from fhir_controller import FhirController
1715
from fhir_repository import ImmunizationRepository
@@ -22,15 +20,12 @@
2220
InvalidPatientId,
2321
CustomValidationError,
2422
ParameterException,
25-
InconsistentIdError,
2623
UnauthorizedVaxError,
27-
UnauthorizedError,
2824
IdentifierDuplicationError,
2925
)
3026
from tests.utils.immunization_utils import create_covid_19_immunization
3127
from parameter_parser import patient_identifier_system, process_search_params
3228
from tests.utils.generic_utils import load_json_data
33-
from tests.utils.values_for_tests import ValidValues
3429

3530
class TestFhirControllerBase(unittest.TestCase):
3631
"""Base class for all tests to set up common fixtures"""
@@ -211,7 +206,7 @@ def test_get_imms_by_identifer_patient_identifier_and_element_present(self, mock
211206
# When
212207
response = self.controller.get_immunization_by_identifier(lambda_event)
213208
# Then
214-
self.service.get_immunization_by_identifier.assert_not_called
209+
self.service.get_immunization_by_identifier.assert_not_called()
215210

216211
self.assertEqual(response["statusCode"], 400)
217212
body = json.loads(response["body"])
@@ -234,7 +229,30 @@ def test_get_imms_by_identifer_both_body_and_query_params_present(self, mock_get
234229
# When
235230
response = self.controller.get_immunization_by_identifier(lambda_event)
236231
# Then
237-
self.service.get_immunization_by_identifier.assert_not_called
232+
self.service.get_immunization_by_identifier.assert_not_called()
233+
234+
self.assertEqual(response["statusCode"], 400)
235+
body = json.loads(response["body"])
236+
self.assertEqual(body["resourceType"], "OperationOutcome")
237+
238+
@patch("fhir_controller.get_supplier_permissions")
239+
def test_get_imms_by_identifer_imms_identifier_and_element_not_present(self, mock_get_supplier_permissions):
240+
"""it should return Immunization Id if it exists"""
241+
# Given
242+
mock_get_supplier_permissions.return_value = ["COVID19.CRUDS"]
243+
self.service.get_immunization_by_identifier.return_value = {"id": "test", "Version": 1}
244+
lambda_event = {
245+
"headers": {"SupplierSystem": "test"},
246+
"queryStringParameters": {
247+
"-immunization.target": "test",
248+
"immunization.identifier": "https://supplierABC/identifiers/vacc|f10b59b3-fc73-4616-99c9-9e882ab31184",
249+
},
250+
"body": None,
251+
}
252+
# When
253+
response = self.controller.get_immunization_by_identifier(lambda_event)
254+
# Then
255+
self.service.get_immunization_by_identifier.assert_not_called()
238256

239257
self.assertEqual(response["statusCode"], 400)
240258
body = json.loads(response["body"])
@@ -258,7 +276,7 @@ def test_get_imms_by_identifer_both_identifier_present(self, mock_get_supplier_p
258276
# When
259277
response = self.controller.get_immunization_by_identifier(lambda_event)
260278
# Then
261-
self.service.get_immunization_by_identifier.assert_not_called
279+
self.service.get_immunization_by_identifier.assert_not_called()
262280

263281
self.assertEqual(response["statusCode"], 400)
264282
body = json.loads(response["body"])
@@ -512,7 +530,7 @@ def test_get_imms_by_identifer_patient_identifier_and_element_present(self, mock
512530
# When
513531
response = self.controller.get_immunization_by_identifier(lambda_event)
514532
# Then
515-
self.service.get_immunization_by_identifier.assert_not_called
533+
self.service.get_immunization_by_identifier.assert_not_called()
516534

517535
self.assertEqual(response["statusCode"], 400)
518536
body = json.loads(response["body"])
@@ -532,7 +550,7 @@ def test_get_imms_by_identifer_imms_identifier_and_element_not_present(self,mock
532550
# When
533551
response = self.controller.get_immunization_by_identifier(lambda_event)
534552
# Then
535-
self.service.get_immunization_by_identifier.assert_not_called
553+
self.service.get_immunization_by_identifier.assert_not_called()
536554

537555
self.assertEqual(response["statusCode"], 400)
538556
body = json.loads(response["body"])
@@ -616,7 +634,7 @@ def test_get_imms_by_identifer_both_identifier_present(self, mock_get_supplier_p
616634
# When
617635
response = self.controller.get_immunization_by_identifier(lambda_event)
618636
# Then
619-
self.service.get_immunization_by_identifier.assert_not_called
637+
self.service.get_immunization_by_identifier.assert_not_called()
620638

621639
self.assertEqual(response["statusCode"], 400)
622640
body = json.loads(response["body"])
@@ -757,7 +775,7 @@ def test_get_imms_by_id(self, mock_permissions):
757775
response = self.controller.get_immunization_by_id(lambda_event)
758776
# Then
759777
mock_permissions.assert_called_once_with("test")
760-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
778+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test")
761779

762780
self.assertEqual(response["statusCode"], 200)
763781
body = json.loads(response["body"])
@@ -780,6 +798,7 @@ def test_get_imms_by_id_unauthorised_vax_error(self,mock_permissions):
780798
# Then
781799
mock_permissions.assert_called_once_with("test")
782800
self.assertEqual(response["statusCode"], 403)
801+
783802
@patch("fhir_controller.get_supplier_permissions")
784803
def test_get_imms_by_id_no_vax_permission(self, mock_permissions):
785804
"""it should return Immunization Id if it exists"""
@@ -814,7 +833,7 @@ def test_not_found(self,mock_permissions):
814833

815834
# Then
816835
mock_permissions.assert_called_once_with("test")
817-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
836+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "test")
818837

819838
self.assertEqual(response["statusCode"], 404)
820839
body = json.loads(response["body"])
@@ -860,7 +879,7 @@ def test_create_immunization(self,mock_get_permissions):
860879

861880
imms_obj = json.loads(aws_event["body"])
862881
mock_get_permissions.assert_called_once_with("Test")
863-
self.service.create_immunization.assert_called_once_with(imms_obj, ["COVID19.CRUDS", "FLU.CRUDS"], "Test")
882+
self.service.create_immunization.assert_called_once_with(imms_obj, "Test")
864883
self.assertEqual(response["statusCode"], 201)
865884
self.assertTrue("body" not in response)
866885
self.assertTrue(response["headers"]["Location"].endswith(f"Immunization/{imms_id}"))
@@ -1271,7 +1290,7 @@ def test_update_record_exists(self, mock_get_supplier_permissions):
12711290
response = self.controller.get_immunization_by_id(lambda_event)
12721291

12731292
# Then
1274-
self.service.get_immunization_by_id.assert_called_once_with(imms_id, ["COVID19.CRUDS"])
1293+
self.service.get_immunization_by_id.assert_called_once_with(imms_id, "Test")
12751294

12761295
self.assertEqual(response["statusCode"], 404)
12771296
body = json.loads(response["body"])

0 commit comments

Comments
 (0)