Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions app/Fixtures/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
import os


"""
App root directory. This is dependent on constants.py being one directory
down from the app directory.
"""
# Two dirname() hops from this file's absolute path: the first gives the
# directory containing this file, the second gives that directory's parent.
APP_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Seed value for code that takes a `random_seed` argument.
RANDOM_SEED = 1067641072

WINSOR_THRESHOLDS = {
Expand Down
10 changes: 7 additions & 3 deletions app/Fixtures/gams.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import pickle
import os, pickle
from constants import APP_ROOT_DIR

# Load the exported study assets once, at import time, from a path anchored on
# APP_ROOT_DIR so that the load does not depend on the current working
# directory. The context manager makes sure the file handle is closed.
# NOTE: unpickling can execute arbitrary code, so production_assets.pkl must
# only ever be a trusted file.
# NOTE(review): APP_ROOT_DIR is imported above via `from constants import ...`
# while other modules import from `app.Fixtures.constants` - confirm the
# import resolves when this module is loaded as `app.Fixtures.gams`.
with open(
    os.path.join(APP_ROOT_DIR, 'Fixtures', 'production_assets.pkl'), "rb"
) as assets_file:
    study_export = pickle.load(assets_file)

# Models pulled out of the export, one per top-level key.
MORTALITY_GAM = study_export["mortality"]["model"]
# Misspelled legacy name, kept as an alias so existing imports keep working.
MORTALTIY_GAM = MORTALITY_GAM
LACTATE_GAM = study_export["lactate"]["model"]
ALBUMIN_GAM = study_export["albumin"]["model"]

Expand Down
4 changes: 2 additions & 2 deletions app/prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.preprocessing import QuantileTransformer
from pygam import GAM, LinearGAM
from pygam.distributions import NormalDist
from app.Fixtures.gams import MORTALTIY_GAM
from app.Fixtures.gams import MORTALITY_GAM


def quick_sample(
Expand Down Expand Up @@ -142,7 +142,7 @@ def predict_mortality(
(features.shape[0] * n_samples_per_row,)
"""
return quick_sample(
gam=MORTALTIY_GAM,
gam=MORTALITY_GAM,
sample_at_X=features,
quantity="mu",
n_draws=n_samples_per_row,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import app.prediction.predict as predict
from app.Fixtures.gams import LACTATE_GAM, LACTATE_TRANSFORMER
from app.Fixtures.gams import MORTALTIY_GAM
from app.Fixtures.gams import MORTALITY_GAM


def lineargam_data(n_rows: int) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
115 changes: 104 additions & 11 deletions tests/test_predict_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
from fastapi.testclient import TestClient
from app.main import api
from app.Fixtures.gams import study_export
from app.prediction.predict import predict_mortality
from app.Fixtures.constants import RANDOM_SEED

client = TestClient(api)

Expand All @@ -9,7 +13,7 @@ def test_index():
assert response.status_code == 200


pred = {
patient1 = {
"Age": 40,
"ASA": 3,
"HR": 87,
Expand All @@ -31,9 +35,57 @@ def test_index():
}


# Example request payload in which lactate and albumin are both observed and
# no value is expected to need Winsorisation.
patient2 = dict(
    Age=81,
    ASA=2,
    HR=82,
    SBP=104,
    WCC=9.1,
    Na=135,
    K=4.4,
    Urea=8.7,
    Creat=78,
    GCS=15,
    Resp=0,
    Cardio=1,
    Sinus=0,
    CT_performed=1,
    Indication=0,
    Malignancy=0,
    Soiling=1,
    Lactate=3.2,
    Albumin=25,
)


# Lookup from each API variable name (key) to the corresponding NELA variable
# name (value).
api_nela_var_map = dict(
    Age="S01AgeOnArrival",
    ASA="S03ASAScore",
    HR="S03Pulse",
    SBP="S03SystolicBloodPressure",
    WCC="S03WhiteCellCount",
    Na="S03Sodium",
    K="S03Potassium",
    Urea="S03Urea",
    Creat="S03SerumCreatinine",
    GCS="S03GlasgowComaScore",
    Resp="S03RespiratorySigns",
    Cardio="S03CardiacSigns",
    Sinus="S03ECG",
    CT_performed="S02PreOpCTPerformed",
    Indication="Indication",
    Malignancy="S03DiagnosedMalignancy",
    Soiling="S03Pred_Peritsoil",
    Lactate="S03PreOpArterialBloodLactate",
    Albumin="S03PreOpLowestAlbumin",
)


def test_predict_api_both_impute():
response = client.post(
"/predict", headers={"Content-Type": "application/json"}, json=pred
"/predict", headers={"Content-Type": "application/json"}, json=patient1
)
assert response.status_code == 200

Expand All @@ -42,10 +94,10 @@ def test_predict_api_both_impute():


def test_predict_api_alb_impute():
pred["Albumin"] = 40
patient1["Albumin"] = 40

response = client.post(
"/predict", headers={"Content-Type": "application/json"}, json=pred
"/predict", headers={"Content-Type": "application/json"}, json=patient1
)
assert response.status_code == 200

Expand All @@ -54,10 +106,10 @@ def test_predict_api_alb_impute():


def test_predict_api_basic():
pred["Lactate"] = 1
patient1["Lactate"] = 1

response = client.post(
"/predict", headers={"Content-Type": "application/json"}, json=pred
"/predict", headers={"Content-Type": "application/json"}, json=patient1
)
assert response.status_code == 200

Expand All @@ -70,19 +122,60 @@ def test_predict_api_basic():


def test_predict_api_invalid_cat():
    """
    Posts patient1 with "Soiling" set to 7 and expects the API to reject the
    request with HTTP 422.

    NOTE: patient1 is the shared module-level dict and is modified in place,
    so the change is still present for any test that uses patient1 afterwards.
    """
    patient1.update(Soiling=7)

    json_headers = {"Content-Type": "application/json"}
    reply = client.post("/predict", headers=json_headers, json=patient1)

    assert reply.status_code == 422


def test_predict_api_invalid_type():
    """
    Posts patient1 with an extra lower-case "soiling" key and a non-integer
    "SBP" value, and expects the API to reject the request with HTTP 422.

    NOTE(review): patient1 is the shared module-level dict and is modified in
    place, so it also carries whatever earlier tests wrote into it - confirm
    the 422 is caused by the values set here.
    """
    patient1.update({"soiling": 1, "SBP": 103.4})

    json_headers = {"Content-Type": "application/json"}
    reply = client.post("/predict", headers=json_headers, json=patient1)

    assert reply.status_code == 422


def test_predict_api_vs_direct_prediction():
    """
    Compares mortality risk predictions from the predict API with those
    generated by calling predict_mortality() directly, for a patient that
    requires neither Winsorisation nor lactate / albumin imputation. The two
    sets of predictions should be the same.
    """
    # Get API mortality risk prediction
    response = client.post(
        "/predict", headers={"Content-Type": "application/json"}, json=patient2
    )
    assert response.status_code == 200
    api_pred = np.array(response.json()["Result"])

    # Take a single row as a template, so that we have a 1-row DataFrame with
    # the same columns as the input to predict_mortality()
    template = study_export['mortality']['input_data']['describe'].iloc[
        5:6
    ].reset_index(drop=True)

    # Translate example patient 2 to NELA variable names, and add the
    # missingness variables (0 because lactate and albumin are both supplied)
    direct_patient = {
        api_nela_var_map[api_name]: value
        for api_name, value in patient2.items()
    }
    direct_patient["S03PreOpLowestAlbumin_missing"] = 0.
    direct_patient['S03PreOpArterialBloodLactate_missing'] = 0.

    # Every patient variable must already be a template column; otherwise
    # assign() below would silently append a new column instead of
    # overwriting an existing one.
    unknown_columns = set(direct_patient) - set(template.columns)
    assert not unknown_columns, (
        f"Variables missing from model input: {sorted(unknown_columns)}"
    )

    # Overwrite the template values with those of example patient 2. Without
    # this step the direct prediction would be made for the template row
    # rather than for the patient sent to the API.
    features = template.assign(**direct_patient)

    # Get direct mortality risk prediction, drawing as many samples as the API
    # returned so that the two arrays can be compared element-wise
    direct_pred = predict_mortality(
        features=features,
        n_samples_per_row=api_pred.size,
        random_seed=RANDOM_SEED
    )

    # Compare predictions
    # TODO: We need to round for the test to pass - why so much numerical error?
    decimal_places = 4
    assert (
        direct_pred.round(decimal_places) == api_pred.round(decimal_places)
    ).all()