diff --git a/clinical_model/clinical_data.ipynb b/clinical_model/clinical_data.ipynb
new file mode 100644
index 0000000..210dbd1
--- /dev/null
+++ b/clinical_model/clinical_data.ipynb
@@ -0,0 +1,2154 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "electoral-korean",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import pathlib\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "boolean-bacteria",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " Fibrinogen | \n",
+ " LDH | \n",
+ " Ddimer | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " SaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0 | \n",
+ " 39.3 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 651.0 | \n",
+ " 387.0 | \n",
+ " 157.0 | \n",
+ " 94.0 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0 | \n",
+ " 37.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 700.0 | \n",
+ " 338.0 | \n",
+ " 601.0 | \n",
+ " 94.0 | \n",
+ " 75.0 | \n",
+ " 96.9 | \n",
+ " 7.42 | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0 | \n",
+ " 37.8 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 698.0 | \n",
+ " 356.0 | \n",
+ " 448.0 | \n",
+ " 94.0 | \n",
+ " 63.0 | \n",
+ " 94.6 | \n",
+ " 7.39 | \n",
+ " 1.0 | \n",
+ " NaN | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0 | \n",
+ " 38.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 513.0 | \n",
+ " 482.0 | \n",
+ " NaN | \n",
+ " 97.0 | \n",
+ " 68.0 | \n",
+ " 96.3 | \n",
+ " 7.46 | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1 | \n",
+ " 37.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 210.0 | \n",
+ " 93.0 | \n",
+ " NaN | \n",
+ " 97.3 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 550.0 | \n",
+ " 368.0 | \n",
+ " 5027.0 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 757.0 | \n",
+ " 451.0 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 1073.0 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1 | \n",
+ " 38.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 533.0 | \n",
+ " 358.0 | \n",
+ " 2154.0 | \n",
+ " NaN | \n",
+ " 75.7 | \n",
+ " NaN | \n",
+ " 7.36 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0 | \n",
+ " 38.4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " NaN | \n",
+ " 149.0 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 80.0 | \n",
+ " NaN | \n",
+ " 7.46 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 17 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0 39.3 1.0 0.0 5.76 \n",
+ "P_132 57.266667 0 37.0 0.0 0.0 11.48 \n",
+ "P_195 79.263889 0 37.8 1.0 0.0 6.21 \n",
+ "P_193 82.000000 0 38.0 1.0 0.0 7.28 \n",
+ "P_140 60.791667 1 37.0 1.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0 NaN 0.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0 NaN 1.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0 NaN 0.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1 38.0 0.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0 38.4 0.0 0.0 6.70 \n",
+ "\n",
+ " CRP Fibrinogen LDH Ddimer Ox_percentage PaO2 SaO2 pH \\\n",
+ "PatientID \n",
+ "P_131 43.40 651.0 387.0 157.0 94.0 NaN NaN NaN \n",
+ "P_132 64.00 700.0 338.0 601.0 94.0 75.0 96.9 7.42 \n",
+ "P_195 115.30 698.0 356.0 448.0 94.0 63.0 94.6 7.39 \n",
+ "P_193 149.30 513.0 482.0 NaN 97.0 68.0 96.3 7.46 \n",
+ "P_140 20.70 NaN NaN 210.0 93.0 NaN 97.3 NaN \n",
+ "... ... ... ... ... ... ... ... ... \n",
+ "P_1_12 22.79 550.0 368.0 5027.0 NaN NaN NaN NaN \n",
+ "P_1_8 8.93 757.0 451.0 NaN NaN NaN NaN NaN \n",
+ "P_1_10 0.23 NaN NaN 1073.0 NaN NaN NaN NaN \n",
+ "P_1_26 3.77 533.0 358.0 2154.0 NaN 75.7 NaN 7.36 \n",
+ "P_1_146 0.02 NaN 149.0 NaN NaN 80.0 NaN 7.46 \n",
+ "\n",
+ " CardiovascularDisease RespiratoryFailure Prognosis \n",
+ "PatientID \n",
+ "P_131 0.0 NaN 0 \n",
+ "P_132 0.0 NaN 0 \n",
+ "P_195 1.0 NaN 1 \n",
+ "P_193 0.0 NaN 1 \n",
+ "P_140 0.0 NaN 0 \n",
+ "... ... ... ... \n",
+ "P_1_12 0.0 0.0 1 \n",
+ "P_1_8 0.0 0.0 1 \n",
+ "P_1_10 0.0 0.0 0 \n",
+ "P_1_26 0.0 0.0 1 \n",
+ "P_1_146 1.0 0.0 0 \n",
+ "\n",
+ "[863 rows x 17 columns]"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_path = pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/trainSet.txt\")\n",
+ "df_train = pd.read_csv(train_path)\n",
+ "df_train.set_index(\"PatientID\", inplace=True)\n",
+ "df_train.drop([\"ImageFile\", \"Hospital\"], axis=1, inplace=True)\n",
+ "df_train[\"Prognosis\"] = (df_train[\"Prognosis\"] == \"SEVERE\").astype(np.uint8)\n",
+ "\n",
+ "# one-hot encode the hospital\n",
+ "#foo = pd.get_dummies(df_train[\"Hospital\"], prefix=\"Hospital\")\n",
+ "#df_train = pd.concat([df_train, foo], axis=1)\n",
+ "#df_train = df_train.drop([\"Hospital\"], axis=1)\n",
+ "\n",
+ "\n",
+ "df_train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "fantastic-struggle",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(6, 17)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# this is how many complete samples we have\n",
+ "df_train.dropna().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eed10a30",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ongoing-dance",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Age \t\t\t 1 / 863\n",
+ "Sex \t\t\t 0 / 863\n",
+ "Temp_C \t\t\t 154 / 863\n",
+ "Cough \t\t\t 5 / 863\n",
+ "DifficultyInBreathing \t\t\t 4 / 863\n",
+ "WBC \t\t\t 9 / 863\n",
+ "CRP \t\t\t 33 / 863\n",
+ "Fibrinogen \t\t\t 591 / 863\n",
+ "LDH \t\t\t 136 / 863\n",
+ "Ddimer \t\t\t 621 / 863\n",
+ "Ox_percentage \t\t\t 243 / 863\n",
+ "PaO2 \t\t\t 170 / 863\n",
+ "SaO2 \t\t\t 583 / 863\n",
+ "pH \t\t\t 207 / 863\n",
+ "CardiovascularDisease \t\t\t 19 / 863\n",
+ "RespiratoryFailure \t\t\t 159 / 863\n",
+ "Prognosis \t\t\t 0 / 863\n",
+ "\n",
+ "Going to discard features with more than 200 nans: ['Fibrinogen', 'Ddimer', 'Ox_percentage', 'SaO2', 'pH']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# count the missing (NaN) values in each column and collect the columns to discard\n",
+ "discard = []\n",
+ "discard_thresh = 200\n",
+ "for c in df_train.columns:\n",
+ " n_missing = df_train[c].isna().sum()\n",
+ " print(c, \"\\t\\t\\t\", n_missing, \"/\", len(df_train))\n",
+ " if n_missing > discard_thresh:\n",
+ " discard.append(c)\n",
+ " \n",
+ "print(f\"\\nGoing to discard features with more than {discard_thresh} nans:\", discard)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "induced-treat",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " PaO2 | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0 | \n",
+ " 39.3 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 387.0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0 | \n",
+ " 37.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 338.0 | \n",
+ " 75.0 | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0 | \n",
+ " 37.8 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 356.0 | \n",
+ " 63.0 | \n",
+ " 1.0 | \n",
+ " NaN | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0 | \n",
+ " 38.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 482.0 | \n",
+ " 68.0 | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1 | \n",
+ " 37.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 368.0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 451.0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0 | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1 | \n",
+ " 38.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.0 | \n",
+ " 75.7 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0 | \n",
+ " 38.4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.0 | \n",
+ " 80.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 12 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0 39.3 1.0 0.0 5.76 \n",
+ "P_132 57.266667 0 37.0 0.0 0.0 11.48 \n",
+ "P_195 79.263889 0 37.8 1.0 0.0 6.21 \n",
+ "P_193 82.000000 0 38.0 1.0 0.0 7.28 \n",
+ "P_140 60.791667 1 37.0 1.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0 NaN 0.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0 NaN 1.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0 NaN 0.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1 38.0 0.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0 38.4 0.0 0.0 6.70 \n",
+ "\n",
+ " CRP LDH PaO2 CardiovascularDisease RespiratoryFailure \\\n",
+ "PatientID \n",
+ "P_131 43.40 387.0 NaN 0.0 NaN \n",
+ "P_132 64.00 338.0 75.0 0.0 NaN \n",
+ "P_195 115.30 356.0 63.0 1.0 NaN \n",
+ "P_193 149.30 482.0 68.0 0.0 NaN \n",
+ "P_140 20.70 NaN NaN 0.0 NaN \n",
+ "... ... ... ... ... ... \n",
+ "P_1_12 22.79 368.0 NaN 0.0 0.0 \n",
+ "P_1_8 8.93 451.0 NaN 0.0 0.0 \n",
+ "P_1_10 0.23 NaN NaN 0.0 0.0 \n",
+ "P_1_26 3.77 358.0 75.7 0.0 0.0 \n",
+ "P_1_146 0.02 149.0 80.0 1.0 0.0 \n",
+ "\n",
+ " Prognosis \n",
+ "PatientID \n",
+ "P_131 0 \n",
+ "P_132 0 \n",
+ "P_195 1 \n",
+ "P_193 1 \n",
+ "P_140 0 \n",
+ "... ... \n",
+ "P_1_12 1 \n",
+ "P_1_8 1 \n",
+ "P_1_10 0 \n",
+ "P_1_26 1 \n",
+ "P_1_146 0 \n",
+ "\n",
+ "[863 rows x 12 columns]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_reduced = df_train.drop(discard, axis=1)\n",
+ "df_reduced"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bizarre-college",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " PaO2 | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " count | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ " 417.000000 | \n",
+ "
\n",
+ " \n",
+ " mean | \n",
+ " 65.167866 | \n",
+ " 0.338129 | \n",
+ " 37.553717 | \n",
+ " 0.522782 | \n",
+ " 0.489209 | \n",
+ " 6.953189 | \n",
+ " 32.663094 | \n",
+ " 370.949640 | \n",
+ " 73.613669 | \n",
+ " 0.294964 | \n",
+ " 0.014388 | \n",
+ " 0.482014 | \n",
+ "
\n",
+ " \n",
+ " std | \n",
+ " 14.489739 | \n",
+ " 0.473641 | \n",
+ " 0.968887 | \n",
+ " 0.500081 | \n",
+ " 0.500484 | \n",
+ " 3.279598 | \n",
+ " 54.263935 | \n",
+ " 229.063855 | \n",
+ " 26.501603 | \n",
+ " 0.456574 | \n",
+ " 0.119229 | \n",
+ " 0.500277 | \n",
+ "
\n",
+ " \n",
+ " min | \n",
+ " 22.000000 | \n",
+ " 0.000000 | \n",
+ " 35.400000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.200000 | \n",
+ " 0.010000 | \n",
+ " 115.000000 | \n",
+ " 23.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 25% | \n",
+ " 56.000000 | \n",
+ " 0.000000 | \n",
+ " 36.800000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 4.700000 | \n",
+ " 5.360000 | \n",
+ " 247.000000 | \n",
+ " 60.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 50% | \n",
+ " 65.000000 | \n",
+ " 0.000000 | \n",
+ " 37.700000 | \n",
+ " 1.000000 | \n",
+ " 0.000000 | \n",
+ " 6.300000 | \n",
+ " 14.420000 | \n",
+ " 318.000000 | \n",
+ " 70.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 75% | \n",
+ " 78.000000 | \n",
+ " 1.000000 | \n",
+ " 38.200000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 8.460000 | \n",
+ " 33.960000 | \n",
+ " 422.000000 | \n",
+ " 80.000000 | \n",
+ " 1.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ " max | \n",
+ " 96.000000 | \n",
+ " 1.000000 | \n",
+ " 40.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 26.500000 | \n",
+ " 413.000000 | \n",
+ " 2578.000000 | \n",
+ " 285.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing \\\n",
+ "count 417.000000 417.000000 417.000000 417.000000 417.000000 \n",
+ "mean 65.167866 0.338129 37.553717 0.522782 0.489209 \n",
+ "std 14.489739 0.473641 0.968887 0.500081 0.500484 \n",
+ "min 22.000000 0.000000 35.400000 0.000000 0.000000 \n",
+ "25% 56.000000 0.000000 36.800000 0.000000 0.000000 \n",
+ "50% 65.000000 0.000000 37.700000 1.000000 0.000000 \n",
+ "75% 78.000000 1.000000 38.200000 1.000000 1.000000 \n",
+ "max 96.000000 1.000000 40.000000 1.000000 1.000000 \n",
+ "\n",
+ " WBC CRP LDH PaO2 CardiovascularDisease \\\n",
+ "count 417.000000 417.000000 417.000000 417.000000 417.000000 \n",
+ "mean 6.953189 32.663094 370.949640 73.613669 0.294964 \n",
+ "std 3.279598 54.263935 229.063855 26.501603 0.456574 \n",
+ "min 0.200000 0.010000 115.000000 23.000000 0.000000 \n",
+ "25% 4.700000 5.360000 247.000000 60.000000 0.000000 \n",
+ "50% 6.300000 14.420000 318.000000 70.000000 0.000000 \n",
+ "75% 8.460000 33.960000 422.000000 80.000000 1.000000 \n",
+ "max 26.500000 413.000000 2578.000000 285.000000 1.000000 \n",
+ "\n",
+ " RespiratoryFailure Prognosis \n",
+ "count 417.000000 417.000000 \n",
+ "mean 0.014388 0.482014 \n",
+ "std 0.119229 0.500277 \n",
+ "min 0.000000 0.000000 \n",
+ "25% 0.000000 0.000000 \n",
+ "50% 0.000000 0.000000 \n",
+ "75% 0.000000 1.000000 \n",
+ "max 1.000000 1.000000 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_reduced = df_reduced.dropna()\n",
+ "df_reduced.describe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "compact-lawsuit",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " PaO2 | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_302 | \n",
+ " 56.0 | \n",
+ " 0 | \n",
+ " 38.3 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 7.85 | \n",
+ " 9.00 | \n",
+ " 159.0 | \n",
+ " 73.2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_301 | \n",
+ " 61.0 | \n",
+ " 0 | \n",
+ " 38.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 3.57 | \n",
+ " 57.40 | \n",
+ " 309.0 | \n",
+ " 57.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_268 | \n",
+ " 64.0 | \n",
+ " 0 | \n",
+ " 36.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 5.26 | \n",
+ " 41.90 | \n",
+ " 299.0 | \n",
+ " 72.2 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_282 | \n",
+ " 54.0 | \n",
+ " 0 | \n",
+ " 36.6 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 7.01 | \n",
+ " 67.90 | \n",
+ " 458.0 | \n",
+ " 61.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_300 | \n",
+ " 44.0 | \n",
+ " 0 | \n",
+ " 37.7 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.18 | \n",
+ " 42.70 | \n",
+ " 243.0 | \n",
+ " 23.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_77 | \n",
+ " 61.0 | \n",
+ " 1 | \n",
+ " 36.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 3.60 | \n",
+ " 0.07 | \n",
+ " 219.0 | \n",
+ " 86.2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_110 | \n",
+ " 56.0 | \n",
+ " 0 | \n",
+ " 36.5 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 7.00 | \n",
+ " 13.51 | \n",
+ " 411.0 | \n",
+ " 75.8 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_126 | \n",
+ " 57.0 | \n",
+ " 0 | \n",
+ " 37.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 4.75 | \n",
+ " 2.47 | \n",
+ " 383.0 | \n",
+ " 61.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.0 | \n",
+ " 1 | \n",
+ " 38.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.0 | \n",
+ " 75.7 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.0 | \n",
+ " 0 | \n",
+ " 38.4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.0 | \n",
+ " 80.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
417 rows × 11 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC CRP \\\n",
+ "PatientID \n",
+ "P_302 56.0 0 38.3 0.0 1.0 7.85 9.00 \n",
+ "P_301 61.0 0 38.0 1.0 1.0 3.57 57.40 \n",
+ "P_268 64.0 0 36.0 1.0 0.0 5.26 41.90 \n",
+ "P_282 54.0 0 36.6 1.0 0.0 7.01 67.90 \n",
+ "P_300 44.0 0 37.7 0.0 0.0 9.18 42.70 \n",
+ "... ... ... ... ... ... ... ... \n",
+ "P_1_77 61.0 1 36.0 1.0 0.0 3.60 0.07 \n",
+ "P_1_110 56.0 0 36.5 1.0 0.0 7.00 13.51 \n",
+ "P_1_126 57.0 0 37.0 1.0 1.0 4.75 2.47 \n",
+ "P_1_26 92.0 1 38.0 0.0 1.0 13.50 3.77 \n",
+ "P_1_146 80.0 0 38.4 0.0 0.0 6.70 0.02 \n",
+ "\n",
+ " LDH PaO2 CardiovascularDisease RespiratoryFailure \n",
+ "PatientID \n",
+ "P_302 159.0 73.2 0.0 0.0 \n",
+ "P_301 309.0 57.0 0.0 0.0 \n",
+ "P_268 299.0 72.2 1.0 0.0 \n",
+ "P_282 458.0 61.0 0.0 0.0 \n",
+ "P_300 243.0 23.0 0.0 0.0 \n",
+ "... ... ... ... ... \n",
+ "P_1_77 219.0 86.2 0.0 0.0 \n",
+ "P_1_110 411.0 75.8 0.0 0.0 \n",
+ "P_1_126 383.0 61.0 0.0 0.0 \n",
+ "P_1_26 358.0 75.7 0.0 0.0 \n",
+ "P_1_146 149.0 80.0 1.0 0.0 \n",
+ "\n",
+ "[417 rows x 11 columns]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "label_col = \"Prognosis\"\n",
+ "\n",
+ "y = df_reduced[[label_col]]\n",
+ "X_df = df_reduced[[c for c in df_reduced.columns if c not in [label_col]]]\n",
+ "\n",
+ "X_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "hollow-match",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "picked_features = X_df.columns\n",
+ "\n",
+ "#picked_features = [\n",
+ "# \"LDH\",\n",
+ "# \"Age\",\n",
+ "# \"PaO2\",\n",
+ "# \"CRP\",\n",
+ "# \"WBC\",\n",
+ "# \"pH\"\n",
+ "#]\n",
+ "\n",
+ "\n",
+ "X = X_df[picked_features]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "26719885",
+ "metadata": {},
+ "outputs": [],
+ "source": [
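+ "# Also load the test set; its missing values were already imputed by Maksim.\n",
+ "# NOTE(review): X_test is later paired with y_test purely by row position, so the rows\n",
+ "# of this file must be in the same patient order as solutionTestSet.txt - please verify.\n",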
+ "\n",
+ "test_path = pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/testSet_filled2.txt\")\n",
+ "df_test = pd.read_csv(test_path)\n",
+ "df_test.set_index(\"PatientID\", inplace=True)\n",
+ "#df_test.drop([\"Unnamed: 0\"], axis=1, inplace=True)\n",
+ "# get rid of the features we dumped from the training set as well\n",
+ "df_test.drop([c for c in df_test.columns if c not in df_reduced.columns], axis=1, inplace=True)\n",
+ "# one-hot encode the hospital\n",
+ "#foo = pd.get_dummies(df_test[\"Hospital\"], prefix=\"Hospital\")\n",
+ "#df_test = pd.concat([df_test, foo], axis=1)\n",
+ "#df_test = df_test.drop([\"Hospital\"], axis=1)\n",
+ "assert sorted(df_test.columns) == sorted(X.columns)\n",
+ "\n",
+ "X_test = df_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "3cec65d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_102.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_117.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_16.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_118.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_114.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_88.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_92.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_86.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_9.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_90.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
120 rows × 1 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Prognosis\n",
+ "PatientID \n",
+ "P_102.png 1\n",
+ "P_117.png 0\n",
+ "P_16.png 1\n",
+ "P_118.png 0\n",
+ "P_114.png 1\n",
+ "... ...\n",
+ "P_88.png 0\n",
+ "P_92.png 1\n",
+ "P_86.png 1\n",
+ "P_9.png 1\n",
+ "P_90.png 1\n",
+ "\n",
+ "[120 rows x 1 columns]"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_test = pd.read_csv(pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/solutionTestSet.txt\"))\n",
+ "y_test.set_index(\"ImageFile\", inplace=True)\n",
+ "y_test.index.rename(\"PatientID\", inplace=True)\n",
+ "y_test[\"Prognosis\"] = (y_test[\"Prognosis\"] == \"SEVERE\").astype(np.uint8)\n",
+ "y_test = y_test[[\"Prognosis\"]]\n",
+ "\n",
+ "y_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "placed-ratio",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import MinMaxScaler, StandardScaler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "applicable-white",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "312 105\n"
+ ]
+ }
+ ],
+ "source": [
+ "train, valid = train_test_split(X_df.index.to_numpy(), test_size=.25, stratify=y[label_col], random_state=52)\n",
+ "print(len(train), len(valid))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "f230c6c9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['P_384' 'P_417' 'P_551' 'P_479' 'P_555' 'P_376' 'P_1_26' 'P_379' 'P_441'\n",
+ " 'P_450' 'P_760' 'P_756' 'P_435' 'P_442' 'P_766' 'P_695' 'P_537' 'P_687'\n",
+ " 'P_467' 'P_1_126' 'P_1_75' 'P_653' 'P_1_149' 'P_411' 'P_358' 'P_740'\n",
+ " 'P_1_81' 'P_634' 'P_616' 'P_382' 'P_342' 'P_519' 'P_476' 'P_639' 'P_759'\n",
+ " 'P_350' 'P_484' 'P_635' 'P_507' 'P_439' 'P_798' 'P_598' 'P_414' 'P_583'\n",
+ " 'P_590' 'P_485' 'P_511' 'P_426' 'P_304' 'P_842' 'P_797' 'P_837' 'P_314'\n",
+ " 'P_572' 'P_641' 'P_803' 'P_1_55' 'P_702' 'P_691' 'P_518' 'P_1_95' 'P_521'\n",
+ " 'P_589' 'P_840' 'P_1_131' 'P_772' 'P_368' 'P_292' 'P_750' 'P_274'\n",
+ " 'P_1_146' 'P_650' 'P_505' 'P_733' 'P_623' 'P_697' 'P_620' 'P_1_140'\n",
+ " 'P_609' 'P_808' 'P_526' 'P_383' 'P_595' 'P_587' 'P_580' 'P_502' 'P_1_22'\n",
+ " 'P_324' 'P_1_72' 'P_725' 'P_336' 'P_491' 'P_781' 'P_407' 'P_546' 'P_645'\n",
+ " 'P_425' 'P_735' 'P_445' 'P_812' 'P_385' 'P_1_66' 'P_743' 'P_560' 'P_633']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(valid)\n",
+ "pd.DataFrame(valid).to_csv(\"/home/starke88/git/covid_data_challenge/clinical_model/valid_ids.csv\", index=False, header=None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "advisory-insight",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array(['P_428', 'P_399', 'P_499', 'P_282'], dtype=object)"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train[:4]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "convertible-function",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(312, 11) (312, 1)\n",
+ "(105, 11) (105, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "mask_train = X.index.isin(train)\n",
+ "feat_names = X.columns\n",
+ "\n",
+ "X_train = X[mask_train]\n",
+ "y_train = y[mask_train]\n",
+ "\n",
+ "X_valid = X[~mask_train]\n",
+ "y_valid = y[~mask_train]\n",
+ "\n",
+ "print(X_train.shape, y_train.shape)\n",
+ "print(X_valid.shape, y_valid.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "posted-mailing",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Age float64\n",
+ "Sex int64\n",
+ "Temp_C float64\n",
+ "Cough float64\n",
+ "DifficultyInBreathing float64\n",
+ "WBC float64\n",
+ "CRP float64\n",
+ "LDH float64\n",
+ "PaO2 float64\n",
+ "CardiovascularDisease float64\n",
+ "RespiratoryFailure float64\n",
+ "dtype: object"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train.dtypes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "altered-payday",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_patient_ids(df):\n",
+ "    \"\"\"Return the index labels of ``df`` as a plain Python list.\n",
+ "\n",
+ "    ``Index.tolist()`` is used instead of ``to_numpy().squeeze().tolist()``:\n",
+ "    squeezing a one-row index yields a 0-d array whose ``tolist()`` is a bare\n",
+ "    scalar, so a single-row frame would not produce a list.\n",
+ "    \"\"\"\n",
+ "    return df.index.tolist()\n",
+ "\n",
+ "scaler = MinMaxScaler()\n",
+ "#scaler = StandardScaler()\n",
+ "\n",
+ "ids_train = get_patient_ids(X_train)\n",
+ "X_train = scaler.fit_transform(X_train.to_numpy())\n",
+ "y_train = y_train.to_numpy().squeeze()\n",
+ "\n",
+ "ids_valid = get_patient_ids(X_valid)\n",
+ "X_valid = scaler.transform(X_valid.to_numpy())\n",
+ "y_valid = y_valid.to_numpy().squeeze()\n",
+ "\n",
+ "ids_test = get_patient_ids(X_test)\n",
+ "X_test = scaler.transform(X_test.to_numpy())\n",
+ "y_test = y_test.to_numpy().squeeze()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "78b2ce16",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(312, 105, 120)"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(ids_train), len(ids_valid), len(ids_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "stuffed-connecticut",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(312, 11) (312,) 312\n",
+ "(105, 11) (105,) 105\n",
+ "(120, 11) (120,) 120\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(X_train.shape, y_train.shape, len(ids_train))\n",
+ "print(X_valid.shape, y_valid.shape, len(ids_valid))\n",
+ "print(X_test.shape, y_test.shape, len(ids_test))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "instant-seating",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'Age': 4.438642485192933, 'Sex': -0.5380532156403621, 'Temp_C': 0.9903044054257356, 'Cough': -0.09706348594397517, 'DifficultyInBreathing': 0.7928339789391249, 'WBC': 0.9729638366525475, 'CRP': 1.5618131479569666, 'LDH': 16.805483182146066, 'PaO2': -0.1568112199646922, 'CardiovascularDisease': -0.2591747799837606, 'RespiratoryFailure': -0.7772138167663587}\n",
+ "\n",
+ "significant coeffs {'Age': 4.438642485192933, 'Sex': -0.5380532156403621, 'Temp_C': 0.9903044054257356, 'Cough': -0.09706348594397517, 'DifficultyInBreathing': 0.7928339789391249, 'WBC': 0.9729638366525475, 'CRP': 1.5618131479569666, 'LDH': 16.805483182146066, 'PaO2': -0.1568112199646922, 'CardiovascularDisease': -0.2591747799837606, 'RespiratoryFailure': -0.7772138167663587}\n",
+ "\n",
+ "Training performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.77 0.79 0.78 162\n",
+ " 1 0.77 0.74 0.75 150\n",
+ "\n",
+ " accuracy 0.77 312\n",
+ " macro avg 0.77 0.77 0.77 312\n",
+ "weighted avg 0.77 0.77 0.77 312\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[128 34]\n",
+ " [ 39 111]]\n",
+ "\n",
+ "-AUC 0.7650617283950617\n",
+ "\n",
+ "Validation performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.74 0.80 0.77 54\n",
+ " 1 0.77 0.71 0.73 51\n",
+ "\n",
+ " accuracy 0.75 105\n",
+ " macro avg 0.75 0.75 0.75 105\n",
+ "weighted avg 0.75 0.75 0.75 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[43 11]\n",
+ " [15 36]]\n",
+ "\n",
+ "-AUC 0.7510893246187365\n",
+ "\n",
+ "Test performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.39 0.89 0.54 35\n",
+ " 1 0.90 0.42 0.58 85\n",
+ "\n",
+ " accuracy 0.56 120\n",
+ " macro avg 0.64 0.65 0.56 120\n",
+ "weighted avg 0.75 0.56 0.57 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[31 4]\n",
+ " [49 36]]\n",
+ "\n",
+ "-AUC 0.6546218487394958\n"
+ ]
+ }
+ ],
+ "source": [
+ "## simple logistic regression\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix\n",
+ "\n",
+ "#model = RandomForestClassifier() \n",
+ "model = LogisticRegression(penalty=\"none\", solver=\"lbfgs\")\n",
+ "\n",
+ "model.fit(X_train, y_train)\n",
+ "if hasattr(model, \"coef_\"):\n",
+ " coeffs = dict(zip(feat_names, model.coef_[0]))\n",
+ " print(coeffs)\n",
+ " print(\"\\nsignificant coeffs\", {k:v for k,v in coeffs.items() if np.abs(v) > .01})\n",
+ "print()\n",
+ "\n",
+ "def evaluate_model(model, X, y):\n",
+ "    \"\"\"Print classification metrics, confusion matrix and ROC AUC of `model` on (X, y).\n",
+ "\n",
+ "    The classification report and the confusion matrix are built from the hard\n",
+ "    labels returned by `model.predict`. The AUC is computed from a continuous\n",
+ "    score instead: feeding hard 0/1 labels to `roc_auc_score` evaluates a single\n",
+ "    operating point (for a binary problem that is the balanced accuracy) rather\n",
+ "    than the ranking quality of the model.\n",
+ "    \"\"\"\n",
+ "    pred_y = model.predict(X)\n",
+ "\n",
+ "    # Pick a ranking score for the AUC: probability of class label 1 if the\n",
+ "    # model offers it, otherwise the decision function.\n",
+ "    if hasattr(model, \"predict_proba\"):\n",
+ "        cls_1_index = np.where(model.classes_ == 1)[0][0]\n",
+ "        score_y = model.predict_proba(X)[:, cls_1_index]\n",
+ "    elif hasattr(model, \"decision_function\"):\n",
+ "        score_y = model.decision_function(X)\n",
+ "    else:\n",
+ "        # No continuous score available: keep the previous label-based value.\n",
+ "        score_y = pred_y\n",
+ "\n",
+ "    print(\"-Classification metrics\\n\", classification_report(y, pred_y))\n",
+ "    print()\n",
+ "    print(\"-Confusion matrix\\n\", confusion_matrix(y, pred_y))\n",
+ "    print()\n",
+ "    print(\"-AUC\", roc_auc_score(y, score_y))\n",
+ " \n",
+ "\n",
+ "def make_prediction_df(model, X, patient_ids, method_name):\n",
+ "    \"\"\"Return a DataFrame pairing each patient ID with its class-1 probability.\n",
+ "\n",
+ "    The frame has the columns `PatientID` and `prediction_<method_name>`; the\n",
+ "    latter holds the `predict_proba` column that belongs to class label 1.\n",
+ "    \"\"\"\n",
+ "    # Position of class label 1 among the columns returned by predict_proba.\n",
+ "    matching_positions = np.where(model.classes_ == 1)[0]\n",
+ "    positive_column = matching_positions[0]\n",
+ "\n",
+ "    class_probabilities = model.predict_proba(X)\n",
+ "    columns = {\n",
+ "        \"PatientID\": patient_ids,\n",
+ "        f\"prediction_{method_name}\": class_probabilities[:, positive_column],\n",
+ "    }\n",
+ "    return pd.DataFrame(columns)\n",
+ " \n",
+ " \n",
+ "print(\"Training performance\\n=====\")\n",
+ "evaluate_model(model, X_train, y_train)\n",
+ "\n",
+ "print(\"\\nValidation performance\\n=====\")\n",
+ "evaluate_model(model, X_valid, y_valid)\n",
+ "\n",
+ "print(\"\\nTest performance\\n=====\")\n",
+ "evaluate_model(model, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "3b488c20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-logistic-regression-nopenalty-lbfgs-trainedOn312Patients\"\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a3c1d937",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "chemical-organ",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if hasattr(model, \"feature_importances_\"):\n",
+ " imp = dict(zip(feat_names, model.feature_importances_))\n",
+ " for k, v in imp.items():\n",
+ " print(k,v)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "scientific-cancer",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fitting 5 folds for each of 4 candidates, totalling 20 fits\n",
+ ": grid-search best parameters {'fit__class_weight': 'balanced', 'fit__penalty': 'l1'}\n",
+ ": grid-search best score 0.7469022017409115\n",
+ "\n",
+ "\n",
+ "Training performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.76 0.78 0.77 162\n",
+ " 1 0.76 0.73 0.74 150\n",
+ "\n",
+ " accuracy 0.76 312\n",
+ " macro avg 0.76 0.76 0.76 312\n",
+ "weighted avg 0.76 0.76 0.76 312\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[127 35]\n",
+ " [ 41 109]]\n",
+ "\n",
+ "-AUC 0.7553086419753087\n",
+ "\n",
+ "Validation performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.75 0.78 0.76 54\n",
+ " 1 0.76 0.73 0.74 51\n",
+ "\n",
+ " accuracy 0.75 105\n",
+ " macro avg 0.75 0.75 0.75 105\n",
+ "weighted avg 0.75 0.75 0.75 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[42 12]\n",
+ " [14 37]]\n",
+ "\n",
+ "-AUC 0.7516339869281046\n",
+ "\n",
+ "Test performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.38 0.91 0.53 35\n",
+ " 1 0.91 0.38 0.53 85\n",
+ "\n",
+ " accuracy 0.53 120\n",
+ " macro avg 0.65 0.65 0.53 120\n",
+ "weighted avg 0.76 0.53 0.53 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[32 3]\n",
+ " [53 32]]\n",
+ "\n",
+ "-AUC 0.6453781512605041\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.model_selection import GridSearchCV, KFold\n",
+ "\n",
+ "# l1 or l2\n",
+ "logistic_pipe = Pipeline([\n",
+ " #(\"scaling\", MinMaxScaler()),\n",
+ " (\"fit\", LogisticRegression(solver=\"liblinear\"))])\n",
+ "\n",
+ "logistic_grid = {\n",
+ " #\"scaling\": [\"passthrough\", MinMaxScaler(), StandardScaler()],\n",
+ " \"fit__penalty\": [\"l1\", \"l2\"],\n",
+ " \"fit__class_weight\": [None, \"balanced\"]}\n",
+ "\n",
+ "\n",
+ "searcher = GridSearchCV(\n",
+ " estimator=logistic_pipe,\n",
+ " param_grid=logistic_grid,\n",
+ " cv=KFold(n_splits=5, shuffle=False),\n",
+ " #scoring=\"neg_mean_absolute_error\",\n",
+ " verbose=1)\n",
+ "searcher.fit(X_train, y_train)\n",
+ "print(\": grid-search best parameters\", searcher.best_params_)\n",
+ "print(\": grid-search best score\", searcher.best_score_)\n",
+ "print()\n",
+ "\n",
+ "print(\"\\nTraining performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_train, y_train)\n",
+ "print(\"\\nValidation performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_valid, y_valid)\n",
+ "print(\"\\nTest performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "daf5b719",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients\"\n",
+ "model = searcher.best_estimator_\n",
+ "\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "3c3165bc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " PatientID | \n",
+ " prediction_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " P_102 | \n",
+ " 0.335200 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " P_117 | \n",
+ " 0.068277 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " P_16 | \n",
+ " 0.135848 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " P_118 | \n",
+ " 0.797524 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " P_114 | \n",
+ " 0.428291 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 115 | \n",
+ " P_88 | \n",
+ " 0.204731 | \n",
+ "
\n",
+ " \n",
+ " 116 | \n",
+ " P_92 | \n",
+ " 0.058076 | \n",
+ "
\n",
+ " \n",
+ " 117 | \n",
+ " P_86 | \n",
+ " 0.494830 | \n",
+ "
\n",
+ " \n",
+ " 118 | \n",
+ " P_9 | \n",
+ " 0.677815 | \n",
+ "
\n",
+ " \n",
+ " 119 | \n",
+ " P_90 | \n",
+ " 0.321286 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
120 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " PatientID \\\n",
+ "0 P_102 \n",
+ "1 P_117 \n",
+ "2 P_16 \n",
+ "3 P_118 \n",
+ "4 P_114 \n",
+ ".. ... \n",
+ "115 P_88 \n",
+ "116 P_92 \n",
+ "117 P_86 \n",
+ "118 P_9 \n",
+ "119 P_90 \n",
+ "\n",
+ " prediction_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients \n",
+ "0 0.335200 \n",
+ "1 0.068277 \n",
+ "2 0.135848 \n",
+ "3 0.797524 \n",
+ "4 0.428291 \n",
+ ".. ... \n",
+ "115 0.204731 \n",
+ "116 0.058076 \n",
+ "117 0.494830 \n",
+ "118 0.677815 \n",
+ "119 0.321286 \n",
+ "\n",
+ "[120 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pred_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "dominican-colonial",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fitting 5 folds for each of 400 candidates, totalling 2000 fits\n",
+ ": grid-search best parameters {'fit__max_depth': 5, 'fit__max_leaf_nodes': 20, 'fit__min_impurity_decrease': 0, 'fit__n_estimators': 10}\n",
+ ": grid-search best score 0.7596006144393241\n",
+ "\n",
+ "\n",
+ "Training performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.89 0.84 0.87 162\n",
+ " 1 0.84 0.89 0.86 150\n",
+ "\n",
+ " accuracy 0.87 312\n",
+ " macro avg 0.87 0.87 0.87 312\n",
+ "weighted avg 0.87 0.87 0.87 312\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[136 26]\n",
+ " [ 16 134]]\n",
+ "\n",
+ "-AUC 0.8664197530864197\n",
+ "\n",
+ "Validation performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.81 0.72 0.76 54\n",
+ " 1 0.74 0.82 0.78 51\n",
+ "\n",
+ " accuracy 0.77 105\n",
+ " macro avg 0.77 0.77 0.77 105\n",
+ "weighted avg 0.78 0.77 0.77 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[39 15]\n",
+ " [ 9 42]]\n",
+ "\n",
+ "-AUC 0.7728758169934641\n",
+ "\n",
+ "Test performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.41 0.89 0.56 35\n",
+ " 1 0.91 0.47 0.62 85\n",
+ "\n",
+ " accuracy 0.59 120\n",
+ " macro avg 0.66 0.68 0.59 120\n",
+ "weighted avg 0.76 0.59 0.60 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[31 4]\n",
+ " [45 40]]\n",
+ "\n",
+ "-AUC 0.6781512605042017\n"
+ ]
+ }
+ ],
+ "source": [
+ "rf_pipe = Pipeline([\n",
+ " #(\"scaling\", MinMaxScaler()),\n",
+ " (\"fit\", RandomForestClassifier(n_estimators=100,\n",
+ " # max_depth=5,\n",
+ " # max_leaf_nodes=20,\n",
+ " # min_impurity_decrease=.1\n",
+ " random_state=42,\n",
+ " ))\n",
+ "])\n",
+ "rf_grid = {\n",
+ " #\"scaling\": [MinMaxScaler(), StandardScaler()],\n",
+ " \"fit__n_estimators\": [5, 10, 20, 100, 200],\n",
+ " \"fit__max_depth\": [None, 5, 10, 20, 50],\n",
+ " \"fit__max_leaf_nodes\": [None, 20, 50, 100],\n",
+ " \"fit__min_impurity_decrease\": [0, 0.1, .5, 1] }\n",
+ "\n",
+ "searcher = GridSearchCV(\n",
+ " estimator=rf_pipe,\n",
+ " param_grid=rf_grid,\n",
+ " cv=KFold(n_splits=5, shuffle=False),\n",
+ " #scoring=\"neg_mean_absolute_error\",\n",
+ " verbose=1)\n",
+ "searcher.fit(X_train, y_train)\n",
+ "\n",
+ "print(\": grid-search best parameters\", searcher.best_params_)\n",
+ "print(\": grid-search best score\", searcher.best_score_)\n",
+ "print()\n",
+ "print(\"\\nTraining performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_train, y_train)\n",
+ "print(\"\\nValidation performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_valid, y_valid)\n",
+ "print(\"\\nTest performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "435a051d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-randomforest-10estimators-trainedOn312Patients\"\n",
+ "model = searcher.best_estimator_\n",
+ "\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aca20c17",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/clinical_model/clinical_data_with_imputed_values.ipynb b/clinical_model/clinical_data_with_imputed_values.ipynb
new file mode 100644
index 0000000..e531bd0
--- /dev/null
+++ b/clinical_model/clinical_data_with_imputed_values.ipynb
@@ -0,0 +1,2319 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "electoral-korean",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import pathlib\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "9aedf0e5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0.0 | \n",
+ " 39.300000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 387.000000 | \n",
+ " 94.000000 | \n",
+ " 78.359259 | \n",
+ " 7.462222 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " MILD | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0.0 | \n",
+ " 37.000000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 338.000000 | \n",
+ " 94.000000 | \n",
+ " 75.000000 | \n",
+ " 7.420000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " MILD | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0.0 | \n",
+ " 37.800000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 356.000000 | \n",
+ " 94.000000 | \n",
+ " 63.000000 | \n",
+ " 7.390000 | \n",
+ " 1.0 | \n",
+ " -1.0 | \n",
+ " SEVERE | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0.0 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 482.000000 | \n",
+ " 97.000000 | \n",
+ " 68.000000 | \n",
+ " 7.460000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " SEVERE | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1.0 | \n",
+ " 37.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " 466.481481 | \n",
+ " 93.000000 | \n",
+ " 66.962963 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " MILD | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0.0 | \n",
+ " 37.533333 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 368.000000 | \n",
+ " 93.703704 | \n",
+ " 69.266667 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " SEVERE | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0.0 | \n",
+ " 37.274074 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 451.000000 | \n",
+ " 93.551852 | \n",
+ " 70.651852 | \n",
+ " 7.451852 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " SEVERE | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 37.659259 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " 212.000000 | \n",
+ " 96.814815 | \n",
+ " 83.962963 | \n",
+ " 7.430370 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " MILD | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1.0 | \n",
+ " 38.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.000000 | \n",
+ " 92.703704 | \n",
+ " 75.700000 | \n",
+ " 7.360000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " SEVERE | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0.0 | \n",
+ " 38.400000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.000000 | \n",
+ " 94.551852 | \n",
+ " 80.000000 | \n",
+ " 7.460000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " MILD | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 14 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0.0 39.300000 0.0 0.0 5.76 \n",
+ "P_132 57.266667 0.0 37.000000 1.0 0.0 11.48 \n",
+ "P_195 79.263889 0.0 37.800000 0.0 0.0 6.21 \n",
+ "P_193 82.000000 0.0 38.000000 0.0 0.0 7.28 \n",
+ "P_140 60.791667 1.0 37.000000 0.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0.0 37.533333 1.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0.0 37.274074 0.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0.0 37.659259 1.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1.0 38.000000 1.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0.0 38.400000 1.0 0.0 6.70 \n",
+ "\n",
+ " CRP LDH Ox_percentage PaO2 pH \\\n",
+ "PatientID \n",
+ "P_131 43.40 387.000000 94.000000 78.359259 7.462222 \n",
+ "P_132 64.00 338.000000 94.000000 75.000000 7.420000 \n",
+ "P_195 115.30 356.000000 94.000000 63.000000 7.390000 \n",
+ "P_193 149.30 482.000000 97.000000 68.000000 7.460000 \n",
+ "P_140 20.70 466.481481 93.000000 66.962963 7.452963 \n",
+ "... ... ... ... ... ... \n",
+ "P_1_12 22.79 368.000000 93.703704 69.266667 7.452963 \n",
+ "P_1_8 8.93 451.000000 93.551852 70.651852 7.451852 \n",
+ "P_1_10 0.23 212.000000 96.814815 83.962963 7.430370 \n",
+ "P_1_26 3.77 358.000000 92.703704 75.700000 7.360000 \n",
+ "P_1_146 0.02 149.000000 94.551852 80.000000 7.460000 \n",
+ "\n",
+ " CardiovascularDisease RespiratoryFailure Prognosis \n",
+ "PatientID \n",
+ "P_131 0.0 -1.0 MILD \n",
+ "P_132 0.0 -1.0 MILD \n",
+ "P_195 1.0 -1.0 SEVERE \n",
+ "P_193 0.0 -1.0 SEVERE \n",
+ "P_140 0.0 -1.0 MILD \n",
+ "... ... ... ... \n",
+ "P_1_12 0.0 0.0 SEVERE \n",
+ "P_1_8 0.0 0.0 SEVERE \n",
+ "P_1_10 0.0 0.0 MILD \n",
+ "P_1_26 0.0 0.0 SEVERE \n",
+ "P_1_146 1.0 0.0 MILD \n",
+ "\n",
+ "[863 rows x 14 columns]"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_path = pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/trainSet_filled2.txt\")\n",
+ "df_train = pd.read_csv(train_path)\n",
+ "df_train.set_index(\"PatientID\", inplace=True)\n",
+ "#df_train.drop([\"Unnamed: 0\"], axis=1, inplace=True)\n",
+ "\n",
+ "#df_old = pd.read_csv(pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/trainSet.txt\"))\n",
+ "#df_old.set_index(\"PatientID\", inplace=True)\n",
+ "#df_old = df_old[[\"Prognosis\"]]\n",
+ "\n",
+ "#df_train = pd.concat([df_train, df_old], axis=1)\n",
+ "\n",
+ "df_train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fc4bbb3d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "boolean-bacteria",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0.0 | \n",
+ " 39.300000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 387.000000 | \n",
+ " 94.000000 | \n",
+ " 78.359259 | \n",
+ " 7.462222 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0.0 | \n",
+ " 37.000000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 338.000000 | \n",
+ " 94.000000 | \n",
+ " 75.000000 | \n",
+ " 7.420000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0.0 | \n",
+ " 37.800000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 356.000000 | \n",
+ " 94.000000 | \n",
+ " 63.000000 | \n",
+ " 7.390000 | \n",
+ " 1.0 | \n",
+ " -1.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0.0 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 482.000000 | \n",
+ " 97.000000 | \n",
+ " 68.000000 | \n",
+ " 7.460000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1.0 | \n",
+ " 37.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " 466.481481 | \n",
+ " 93.000000 | \n",
+ " 66.962963 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0.0 | \n",
+ " 37.533333 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 368.000000 | \n",
+ " 93.703704 | \n",
+ " 69.266667 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0.0 | \n",
+ " 37.274074 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 451.000000 | \n",
+ " 93.551852 | \n",
+ " 70.651852 | \n",
+ " 7.451852 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 37.659259 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " 212.000000 | \n",
+ " 96.814815 | \n",
+ " 83.962963 | \n",
+ " 7.430370 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1.0 | \n",
+ " 38.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.000000 | \n",
+ " 92.703704 | \n",
+ " 75.700000 | \n",
+ " 7.360000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0.0 | \n",
+ " 38.400000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.000000 | \n",
+ " 94.551852 | \n",
+ " 80.000000 | \n",
+ " 7.460000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 14 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0.0 39.300000 0.0 0.0 5.76 \n",
+ "P_132 57.266667 0.0 37.000000 1.0 0.0 11.48 \n",
+ "P_195 79.263889 0.0 37.800000 0.0 0.0 6.21 \n",
+ "P_193 82.000000 0.0 38.000000 0.0 0.0 7.28 \n",
+ "P_140 60.791667 1.0 37.000000 0.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0.0 37.533333 1.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0.0 37.274074 0.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0.0 37.659259 1.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1.0 38.000000 1.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0.0 38.400000 1.0 0.0 6.70 \n",
+ "\n",
+ " CRP LDH Ox_percentage PaO2 pH \\\n",
+ "PatientID \n",
+ "P_131 43.40 387.000000 94.000000 78.359259 7.462222 \n",
+ "P_132 64.00 338.000000 94.000000 75.000000 7.420000 \n",
+ "P_195 115.30 356.000000 94.000000 63.000000 7.390000 \n",
+ "P_193 149.30 482.000000 97.000000 68.000000 7.460000 \n",
+ "P_140 20.70 466.481481 93.000000 66.962963 7.452963 \n",
+ "... ... ... ... ... ... \n",
+ "P_1_12 22.79 368.000000 93.703704 69.266667 7.452963 \n",
+ "P_1_8 8.93 451.000000 93.551852 70.651852 7.451852 \n",
+ "P_1_10 0.23 212.000000 96.814815 83.962963 7.430370 \n",
+ "P_1_26 3.77 358.000000 92.703704 75.700000 7.360000 \n",
+ "P_1_146 0.02 149.000000 94.551852 80.000000 7.460000 \n",
+ "\n",
+ " CardiovascularDisease RespiratoryFailure Prognosis \n",
+ "PatientID \n",
+ "P_131 0.0 -1.0 0 \n",
+ "P_132 0.0 -1.0 0 \n",
+ "P_195 1.0 -1.0 1 \n",
+ "P_193 0.0 -1.0 1 \n",
+ "P_140 0.0 -1.0 0 \n",
+ "... ... ... ... \n",
+ "P_1_12 0.0 0.0 1 \n",
+ "P_1_8 0.0 0.0 1 \n",
+ "P_1_10 0.0 0.0 0 \n",
+ "P_1_26 0.0 0.0 1 \n",
+ "P_1_146 1.0 0.0 0 \n",
+ "\n",
+ "[863 rows x 14 columns]"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "#df_train.drop([\"ImageFile\"], axis=1, inplace=True)\n",
+ "df_train[\"Prognosis\"] = (df_train[\"Prognosis\"] == \"SEVERE\").astype(np.uint8)\n",
+ "\n",
+ "# one-hot encode the hospital\n",
+ "#foo = pd.get_dummies(df_train[\"Hospital\"], prefix=\"Hospital\")\n",
+ "#df_train = pd.concat([df_train, foo], axis=1)\n",
+ "#df_train = df_train.drop([\"Hospital\"], axis=1)\n",
+ "\n",
+ "\n",
+ "df_train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "fantastic-struggle",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(863, 14)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# this is how many complete samples we have\n",
+ "df_train.dropna().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18852aa5",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ongoing-dance",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Age \t\t\t 0 / 863\n",
+ "Sex \t\t\t 0 / 863\n",
+ "Temp_C \t\t\t 0 / 863\n",
+ "Cough \t\t\t 0 / 863\n",
+ "DifficultyInBreathing \t\t\t 0 / 863\n",
+ "WBC \t\t\t 0 / 863\n",
+ "CRP \t\t\t 0 / 863\n",
+ "LDH \t\t\t 0 / 863\n",
+ "Ox_percentage \t\t\t 0 / 863\n",
+ "PaO2 \t\t\t 0 / 863\n",
+ "pH \t\t\t 0 / 863\n",
+ "CardiovascularDisease \t\t\t 0 / 863\n",
+ "RespiratoryFailure \t\t\t 0 / 863\n",
+ "Prognosis \t\t\t 0 / 863\n",
+ "\n",
+ "Going to discard features with more than 200 nans: []\n"
+ ]
+ }
+ ],
+ "source": [
+ "# how many samples are missing for individual columns\n",
+ "discard = []\n",
+ "discard_thresh = 200\n",
+ "for c in df_train.columns:\n",
+ " n_missing = df_train[c].isna().sum()\n",
+ " print(c, \"\\t\\t\\t\", n_missing, \"/\", len(df_train))\n",
+ " if n_missing > discard_thresh:\n",
+ " discard.append(c)\n",
+ " \n",
+ "print(f\"\\nGoing to discard features with more than {discard_thresh} nans:\", discard)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "induced-treat",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0.0 | \n",
+ " 39.300000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 387.000000 | \n",
+ " 94.000000 | \n",
+ " 78.359259 | \n",
+ " 7.462222 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0.0 | \n",
+ " 37.000000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 338.000000 | \n",
+ " 94.000000 | \n",
+ " 75.000000 | \n",
+ " 7.420000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0.0 | \n",
+ " 37.800000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 356.000000 | \n",
+ " 94.000000 | \n",
+ " 63.000000 | \n",
+ " 7.390000 | \n",
+ " 1.0 | \n",
+ " -1.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0.0 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 482.000000 | \n",
+ " 97.000000 | \n",
+ " 68.000000 | \n",
+ " 7.460000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1.0 | \n",
+ " 37.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " 466.481481 | \n",
+ " 93.000000 | \n",
+ " 66.962963 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0.0 | \n",
+ " 37.533333 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 368.000000 | \n",
+ " 93.703704 | \n",
+ " 69.266667 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0.0 | \n",
+ " 37.274074 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 451.000000 | \n",
+ " 93.551852 | \n",
+ " 70.651852 | \n",
+ " 7.451852 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 37.659259 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " 212.000000 | \n",
+ " 96.814815 | \n",
+ " 83.962963 | \n",
+ " 7.430370 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1.0 | \n",
+ " 38.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.000000 | \n",
+ " 92.703704 | \n",
+ " 75.700000 | \n",
+ " 7.360000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0.0 | \n",
+ " 38.400000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.000000 | \n",
+ " 94.551852 | \n",
+ " 80.000000 | \n",
+ " 7.460000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 14 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0.0 39.300000 0.0 0.0 5.76 \n",
+ "P_132 57.266667 0.0 37.000000 1.0 0.0 11.48 \n",
+ "P_195 79.263889 0.0 37.800000 0.0 0.0 6.21 \n",
+ "P_193 82.000000 0.0 38.000000 0.0 0.0 7.28 \n",
+ "P_140 60.791667 1.0 37.000000 0.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0.0 37.533333 1.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0.0 37.274074 0.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0.0 37.659259 1.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1.0 38.000000 1.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0.0 38.400000 1.0 0.0 6.70 \n",
+ "\n",
+ " CRP LDH Ox_percentage PaO2 pH \\\n",
+ "PatientID \n",
+ "P_131 43.40 387.000000 94.000000 78.359259 7.462222 \n",
+ "P_132 64.00 338.000000 94.000000 75.000000 7.420000 \n",
+ "P_195 115.30 356.000000 94.000000 63.000000 7.390000 \n",
+ "P_193 149.30 482.000000 97.000000 68.000000 7.460000 \n",
+ "P_140 20.70 466.481481 93.000000 66.962963 7.452963 \n",
+ "... ... ... ... ... ... \n",
+ "P_1_12 22.79 368.000000 93.703704 69.266667 7.452963 \n",
+ "P_1_8 8.93 451.000000 93.551852 70.651852 7.451852 \n",
+ "P_1_10 0.23 212.000000 96.814815 83.962963 7.430370 \n",
+ "P_1_26 3.77 358.000000 92.703704 75.700000 7.360000 \n",
+ "P_1_146 0.02 149.000000 94.551852 80.000000 7.460000 \n",
+ "\n",
+ " CardiovascularDisease RespiratoryFailure Prognosis \n",
+ "PatientID \n",
+ "P_131 0.0 -1.0 0 \n",
+ "P_132 0.0 -1.0 0 \n",
+ "P_195 1.0 -1.0 1 \n",
+ "P_193 0.0 -1.0 1 \n",
+ "P_140 0.0 -1.0 0 \n",
+ "... ... ... ... \n",
+ "P_1_12 0.0 0.0 1 \n",
+ "P_1_8 0.0 0.0 1 \n",
+ "P_1_10 0.0 0.0 0 \n",
+ "P_1_26 0.0 0.0 1 \n",
+ "P_1_146 1.0 0.0 0 \n",
+ "\n",
+ "[863 rows x 14 columns]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Since all values have been imputed, nothing should be missing here anymore.\n",
+ "df_reduced = df_train.drop(discard, axis=1)\n",
+ "df_reduced"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bizarre-college",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " count | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ " 863.000000 | \n",
+ "
\n",
+ " \n",
+ " mean | \n",
+ " 64.443784 | \n",
+ " 0.337196 | \n",
+ " 37.591352 | \n",
+ " 0.483198 | \n",
+ " 0.490151 | \n",
+ " 7.057532 | \n",
+ " 40.568481 | \n",
+ " 362.471459 | \n",
+ " 92.673928 | \n",
+ " 72.734359 | \n",
+ " 7.454625 | \n",
+ " 0.249131 | \n",
+ " -0.171495 | \n",
+ " 0.492468 | \n",
+ "
\n",
+ " \n",
+ " std | \n",
+ " 15.044203 | \n",
+ " 0.473027 | \n",
+ " 0.884709 | \n",
+ " 0.511477 | \n",
+ " 0.509386 | \n",
+ " 3.514950 | \n",
+ " 65.736724 | \n",
+ " 217.412325 | \n",
+ " 6.040142 | \n",
+ " 23.537867 | \n",
+ " 0.049681 | \n",
+ " 0.481004 | \n",
+ " 0.409599 | \n",
+ " 0.500233 | \n",
+ "
\n",
+ " \n",
+ " min | \n",
+ " 18.000000 | \n",
+ " 0.000000 | \n",
+ " 35.400000 | \n",
+ " -1.000000 | \n",
+ " -1.000000 | \n",
+ " 0.200000 | \n",
+ " 0.010000 | \n",
+ " 0.020000 | \n",
+ " 50.000000 | \n",
+ " 23.000000 | \n",
+ " 7.048000 | \n",
+ " -1.000000 | \n",
+ " -1.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 25% | \n",
+ " 54.000000 | \n",
+ " 0.000000 | \n",
+ " 37.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 4.660000 | \n",
+ " 5.970000 | \n",
+ " 251.000000 | \n",
+ " 91.153704 | \n",
+ " 61.900000 | \n",
+ " 7.437037 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 50% | \n",
+ " 65.000000 | \n",
+ " 0.000000 | \n",
+ " 37.603704 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 6.270000 | \n",
+ " 15.420000 | \n",
+ " 321.370370 | \n",
+ " 94.000000 | \n",
+ " 71.522222 | \n",
+ " 7.454741 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 75% | \n",
+ " 77.000000 | \n",
+ " 1.000000 | \n",
+ " 38.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 8.400000 | \n",
+ " 42.700000 | \n",
+ " 405.500000 | \n",
+ " 96.000000 | \n",
+ " 78.990741 | \n",
+ " 7.480000 | \n",
+ " 1.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ " max | \n",
+ " 97.000000 | \n",
+ " 1.000000 | \n",
+ " 40.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 27.460000 | \n",
+ " 570.500000 | \n",
+ " 2903.000000 | \n",
+ " 100.000000 | \n",
+ " 285.000000 | \n",
+ " 7.990000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing \\\n",
+ "count 863.000000 863.000000 863.000000 863.000000 863.000000 \n",
+ "mean 64.443784 0.337196 37.591352 0.483198 0.490151 \n",
+ "std 15.044203 0.473027 0.884709 0.511477 0.509386 \n",
+ "min 18.000000 0.000000 35.400000 -1.000000 -1.000000 \n",
+ "25% 54.000000 0.000000 37.000000 0.000000 0.000000 \n",
+ "50% 65.000000 0.000000 37.603704 0.000000 0.000000 \n",
+ "75% 77.000000 1.000000 38.000000 1.000000 1.000000 \n",
+ "max 97.000000 1.000000 40.000000 1.000000 1.000000 \n",
+ "\n",
+ " WBC CRP LDH Ox_percentage PaO2 \\\n",
+ "count 863.000000 863.000000 863.000000 863.000000 863.000000 \n",
+ "mean 7.057532 40.568481 362.471459 92.673928 72.734359 \n",
+ "std 3.514950 65.736724 217.412325 6.040142 23.537867 \n",
+ "min 0.200000 0.010000 0.020000 50.000000 23.000000 \n",
+ "25% 4.660000 5.970000 251.000000 91.153704 61.900000 \n",
+ "50% 6.270000 15.420000 321.370370 94.000000 71.522222 \n",
+ "75% 8.400000 42.700000 405.500000 96.000000 78.990741 \n",
+ "max 27.460000 570.500000 2903.000000 100.000000 285.000000 \n",
+ "\n",
+ " pH CardiovascularDisease RespiratoryFailure Prognosis \n",
+ "count 863.000000 863.000000 863.000000 863.000000 \n",
+ "mean 7.454625 0.249131 -0.171495 0.492468 \n",
+ "std 0.049681 0.481004 0.409599 0.500233 \n",
+ "min 7.048000 -1.000000 -1.000000 0.000000 \n",
+ "25% 7.437037 0.000000 0.000000 0.000000 \n",
+ "50% 7.454741 0.000000 0.000000 0.000000 \n",
+ "75% 7.480000 1.000000 0.000000 1.000000 \n",
+ "max 7.990000 1.000000 1.000000 1.000000 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_reduced = df_reduced.dropna()\n",
+ "df_reduced.describe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "compact-lawsuit",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Age | \n",
+ " Sex | \n",
+ " Temp_C | \n",
+ " Cough | \n",
+ " DifficultyInBreathing | \n",
+ " WBC | \n",
+ " CRP | \n",
+ " LDH | \n",
+ " Ox_percentage | \n",
+ " PaO2 | \n",
+ " pH | \n",
+ " CardiovascularDisease | \n",
+ " RespiratoryFailure | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_131 | \n",
+ " 35.913889 | \n",
+ " 0.0 | \n",
+ " 39.300000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.76 | \n",
+ " 43.40 | \n",
+ " 387.000000 | \n",
+ " 94.000000 | \n",
+ " 78.359259 | \n",
+ " 7.462222 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ "
\n",
+ " \n",
+ " P_132 | \n",
+ " 57.266667 | \n",
+ " 0.0 | \n",
+ " 37.000000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 11.48 | \n",
+ " 64.00 | \n",
+ " 338.000000 | \n",
+ " 94.000000 | \n",
+ " 75.000000 | \n",
+ " 7.420000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ "
\n",
+ " \n",
+ " P_195 | \n",
+ " 79.263889 | \n",
+ " 0.0 | \n",
+ " 37.800000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.21 | \n",
+ " 115.30 | \n",
+ " 356.000000 | \n",
+ " 94.000000 | \n",
+ " 63.000000 | \n",
+ " 7.390000 | \n",
+ " 1.0 | \n",
+ " -1.0 | \n",
+ "
\n",
+ " \n",
+ " P_193 | \n",
+ " 82.000000 | \n",
+ " 0.0 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 7.28 | \n",
+ " 149.30 | \n",
+ " 482.000000 | \n",
+ " 97.000000 | \n",
+ " 68.000000 | \n",
+ " 7.460000 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ "
\n",
+ " \n",
+ " P_140 | \n",
+ " 60.791667 | \n",
+ " 1.0 | \n",
+ " 37.000000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 6.37 | \n",
+ " 20.70 | \n",
+ " 466.481481 | \n",
+ " 93.000000 | \n",
+ " 66.962963 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " -1.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_1_12 | \n",
+ " 51.000000 | \n",
+ " 0.0 | \n",
+ " 37.533333 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 14.30 | \n",
+ " 22.79 | \n",
+ " 368.000000 | \n",
+ " 93.703704 | \n",
+ " 69.266667 | \n",
+ " 7.452963 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_8 | \n",
+ " 57.000000 | \n",
+ " 0.0 | \n",
+ " 37.274074 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 5.10 | \n",
+ " 8.93 | \n",
+ " 451.000000 | \n",
+ " 93.551852 | \n",
+ " 70.651852 | \n",
+ " 7.451852 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_10 | \n",
+ " 38.000000 | \n",
+ " 0.0 | \n",
+ " 37.659259 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 7.30 | \n",
+ " 0.23 | \n",
+ " 212.000000 | \n",
+ " 96.814815 | \n",
+ " 83.962963 | \n",
+ " 7.430370 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_26 | \n",
+ " 92.000000 | \n",
+ " 1.0 | \n",
+ " 38.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 13.50 | \n",
+ " 3.77 | \n",
+ " 358.000000 | \n",
+ " 92.703704 | \n",
+ " 75.700000 | \n",
+ " 7.360000 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " P_1_146 | \n",
+ " 80.000000 | \n",
+ " 0.0 | \n",
+ " 38.400000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 6.70 | \n",
+ " 0.02 | \n",
+ " 149.000000 | \n",
+ " 94.551852 | \n",
+ " 80.000000 | \n",
+ " 7.460000 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
863 rows × 13 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Age Sex Temp_C Cough DifficultyInBreathing WBC \\\n",
+ "PatientID \n",
+ "P_131 35.913889 0.0 39.300000 0.0 0.0 5.76 \n",
+ "P_132 57.266667 0.0 37.000000 1.0 0.0 11.48 \n",
+ "P_195 79.263889 0.0 37.800000 0.0 0.0 6.21 \n",
+ "P_193 82.000000 0.0 38.000000 0.0 0.0 7.28 \n",
+ "P_140 60.791667 1.0 37.000000 0.0 0.0 6.37 \n",
+ "... ... ... ... ... ... ... \n",
+ "P_1_12 51.000000 0.0 37.533333 1.0 1.0 14.30 \n",
+ "P_1_8 57.000000 0.0 37.274074 0.0 0.0 5.10 \n",
+ "P_1_10 38.000000 0.0 37.659259 1.0 1.0 7.30 \n",
+ "P_1_26 92.000000 1.0 38.000000 1.0 1.0 13.50 \n",
+ "P_1_146 80.000000 0.0 38.400000 1.0 0.0 6.70 \n",
+ "\n",
+ " CRP LDH Ox_percentage PaO2 pH \\\n",
+ "PatientID \n",
+ "P_131 43.40 387.000000 94.000000 78.359259 7.462222 \n",
+ "P_132 64.00 338.000000 94.000000 75.000000 7.420000 \n",
+ "P_195 115.30 356.000000 94.000000 63.000000 7.390000 \n",
+ "P_193 149.30 482.000000 97.000000 68.000000 7.460000 \n",
+ "P_140 20.70 466.481481 93.000000 66.962963 7.452963 \n",
+ "... ... ... ... ... ... \n",
+ "P_1_12 22.79 368.000000 93.703704 69.266667 7.452963 \n",
+ "P_1_8 8.93 451.000000 93.551852 70.651852 7.451852 \n",
+ "P_1_10 0.23 212.000000 96.814815 83.962963 7.430370 \n",
+ "P_1_26 3.77 358.000000 92.703704 75.700000 7.360000 \n",
+ "P_1_146 0.02 149.000000 94.551852 80.000000 7.460000 \n",
+ "\n",
+ " CardiovascularDisease RespiratoryFailure \n",
+ "PatientID \n",
+ "P_131 0.0 -1.0 \n",
+ "P_132 0.0 -1.0 \n",
+ "P_195 1.0 -1.0 \n",
+ "P_193 0.0 -1.0 \n",
+ "P_140 0.0 -1.0 \n",
+ "... ... ... \n",
+ "P_1_12 0.0 0.0 \n",
+ "P_1_8 0.0 0.0 \n",
+ "P_1_10 0.0 0.0 \n",
+ "P_1_26 0.0 0.0 \n",
+ "P_1_146 1.0 0.0 \n",
+ "\n",
+ "[863 rows x 13 columns]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "label_col = \"Prognosis\"\n",
+ "\n",
+ "y = df_reduced[[label_col]]\n",
+ "X_df = df_reduced[[c for c in df_reduced.columns if c not in [label_col]]]\n",
+ "\n",
+ "X_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "hollow-match",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "picked_features = X_df.columns\n",
+ "\n",
+ "#picked_features = [\n",
+ "# \"LDH\",\n",
+ "# \"Age\",\n",
+ "# \"PaO2\",\n",
+ "# \"CRP\",\n",
+ "# \"WBC\",\n",
+ "# \"pH\"\n",
+ "#]\n",
+ "\n",
+ "#\n",
+ "X = X_df[picked_features]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "38c4a8cc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Now also load the test data, for which Maksim imputed the missing values.\n",
+ "\n",
+ "test_path = pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/testSet_filled2.txt\")\n",
+ "df_test = pd.read_csv(test_path)\n",
+ "df_test.set_index(\"PatientID\", inplace=True)\n",
+ "#df_test.drop([\"Unnamed: 0\"], axis=1, inplace=True)\n",
+ "# Drop every test column that is not present in the reduced training set.\n",
+ "df_test.drop([c for c in df_test.columns if c not in df_reduced.columns], axis=1, inplace=True)\n",
+ "# one-hot encode the hospital\n",
+ "#foo = pd.get_dummies(df_test[\"Hospital\"], prefix=\"Hospital\")\n",
+ "#df_test = pd.concat([df_test, foo], axis=1)\n",
+ "#df_test = df_test.drop([\"Hospital\"], axis=1)\n",
+ "assert sorted(df_test.columns) == sorted(X.columns)\n",
+ "\n",
+ "\n",
+ "X_test = df_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "7c989ecb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Prognosis | \n",
+ "
\n",
+ " \n",
+ " PatientID | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " P_102.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_117.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_16.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_118.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_114.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " P_88.png | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " P_92.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_86.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_9.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " P_90.png | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
120 rows × 1 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Prognosis\n",
+ "PatientID \n",
+ "P_102.png 1\n",
+ "P_117.png 0\n",
+ "P_16.png 1\n",
+ "P_118.png 0\n",
+ "P_114.png 1\n",
+ "... ...\n",
+ "P_88.png 0\n",
+ "P_92.png 1\n",
+ "P_86.png 1\n",
+ "P_9.png 1\n",
+ "P_90.png 1\n",
+ "\n",
+ "[120 rows x 1 columns]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_test = pd.read_csv(pathlib.Path(\"/home/starke88/data/haicu/covid_data_challenge_2021/solutionTestSet.txt\"))\n",
+ "y_test.set_index(\"ImageFile\", inplace=True)\n",
+ "y_test.index.rename(\"PatientID\", inplace=True)\n",
+ "y_test[\"Prognosis\"] = (y_test[\"Prognosis\"] == \"SEVERE\").astype(np.uint8)\n",
+ "y_test = y_test[[\"Prognosis\"]]\n",
+ "\n",
+ "y_test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "placed-ratio",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.preprocessing import MinMaxScaler, StandardScaler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "applicable-white",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['P_384', 'P_417', 'P_551', 'P_479', 'P_555', 'P_376', 'P_1_26', 'P_379', 'P_441', 'P_450', 'P_760', 'P_756', 'P_435', 'P_442', 'P_766', 'P_695', 'P_537', 'P_687', 'P_467', 'P_1_126', 'P_1_75', 'P_653', 'P_1_149', 'P_411', 'P_358', 'P_740', 'P_1_81', 'P_634', 'P_616', 'P_382', 'P_342', 'P_519', 'P_476', 'P_639', 'P_759', 'P_350', 'P_484', 'P_635', 'P_507', 'P_439', 'P_798', 'P_598', 'P_414', 'P_583', 'P_590', 'P_485', 'P_511', 'P_426', 'P_304', 'P_842', 'P_797', 'P_837', 'P_314', 'P_572', 'P_641', 'P_803', 'P_1_55', 'P_702', 'P_691', 'P_518', 'P_1_95', 'P_521', 'P_589', 'P_840', 'P_1_131', 'P_772', 'P_368', 'P_292', 'P_750', 'P_274', 'P_1_146', 'P_650', 'P_505', 'P_733', 'P_623', 'P_697', 'P_620', 'P_1_140', 'P_609', 'P_808', 'P_526', 'P_383', 'P_595', 'P_587', 'P_580', 'P_502', 'P_1_22', 'P_324', 'P_1_72', 'P_725', 'P_336', 'P_491', 'P_781', 'P_407', 'P_546', 'P_645', 'P_425', 'P_735', 'P_445', 'P_812', 'P_385', 'P_1_66', 'P_743', 'P_560', 'P_633']\n",
+ "105\n",
+ "758\n"
+ ]
+ }
+ ],
+ "source": [
+ "valid = list(pd.read_csv(\"/home/starke88/git/covid_data_challenge/clinical_model/valid_ids.csv\", header=None).to_numpy().squeeze())\n",
+ "print(valid)\n",
+ "\n",
+ "print(len(valid))\n",
+ "\n",
+ "\n",
+ "# Training patients are all remaining patients, i.e. those not in the validation list.\n",
+ "train = [i for i in X_df.index if i not in valid]\n",
+ "print(len(train))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "convertible-function",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(758, 13) (758, 1)\n",
+ "(105, 13) (105, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "mask_train = X.index.isin(train)\n",
+ "feat_names = X.columns\n",
+ "\n",
+ "X_train = X[mask_train]\n",
+ "y_train = y[mask_train]\n",
+ "\n",
+ "X_valid = X[~mask_train]\n",
+ "y_valid = y[~mask_train]\n",
+ "\n",
+ "print(X_train.shape, y_train.shape)\n",
+ "print(X_valid.shape, y_valid.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "posted-mailing",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Age float64\n",
+ "Sex float64\n",
+ "Temp_C float64\n",
+ "Cough float64\n",
+ "DifficultyInBreathing float64\n",
+ "WBC float64\n",
+ "CRP float64\n",
+ "LDH float64\n",
+ "Ox_percentage float64\n",
+ "PaO2 float64\n",
+ "pH float64\n",
+ "CardiovascularDisease float64\n",
+ "RespiratoryFailure float64\n",
+ "dtype: object"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train.dtypes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "altered-payday",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_patient_ids(df):\n",
+ " return df.index.to_numpy().squeeze().tolist()\n",
+ "\n",
+ "scaler = MinMaxScaler()\n",
+ "#scaler = StandardScaler()\n",
+ "\n",
+ "ids_train = get_patient_ids(X_train)\n",
+ "X_train = scaler.fit_transform(X_train.to_numpy())\n",
+ "y_train = y_train.to_numpy().squeeze()\n",
+ "\n",
+ "ids_valid = get_patient_ids(X_valid)\n",
+ "X_valid = scaler.transform(X_valid.to_numpy())\n",
+ "y_valid = y_valid.to_numpy().squeeze()\n",
+ "\n",
+ "ids_test = get_patient_ids(X_test)\n",
+ "X_test = scaler.transform(X_test.to_numpy())\n",
+ "y_test = y_test.to_numpy().squeeze()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "stuffed-connecticut",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(758, 13) (758,) 758\n",
+ "(105, 13) (105,) 105\n",
+ "(120, 13) (120,) 120\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(X_train.shape, y_train.shape, len(ids_train))\n",
+ "print(X_valid.shape, y_valid.shape, len(ids_valid))\n",
+ "print(X_test.shape, y_test.shape, len(ids_test))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "instant-seating",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'Age': 2.6797615096744147, 'Sex': -0.6752236739227908, 'Temp_C': 0.8531413177461858, 'Cough': -0.08022687494637566, 'DifficultyInBreathing': 0.9606509001177268, 'WBC': 0.9440718078248774, 'CRP': 4.435017116517157, 'LDH': 22.117162047581637, 'Ox_percentage': -4.870467925676456, 'PaO2': 1.5059465620389225, 'pH': -2.957910963613239, 'CardiovascularDisease': -0.0255222153020973, 'RespiratoryFailure': 2.067007555935972}\n",
+ "\n",
+ "significant coeffs {'Age': 2.6797615096744147, 'Sex': -0.6752236739227908, 'Temp_C': 0.8531413177461858, 'Cough': -0.08022687494637566, 'DifficultyInBreathing': 0.9606509001177268, 'WBC': 0.9440718078248774, 'CRP': 4.435017116517157, 'LDH': 22.117162047581637, 'Ox_percentage': -4.870467925676456, 'PaO2': 1.5059465620389225, 'pH': -2.957910963613239, 'CardiovascularDisease': -0.0255222153020973, 'RespiratoryFailure': 2.067007555935972}\n",
+ "\n",
+ "Training performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.78 0.80 0.79 384\n",
+ " 1 0.79 0.77 0.78 374\n",
+ "\n",
+ " accuracy 0.78 758\n",
+ " macro avg 0.78 0.78 0.78 758\n",
+ "weighted avg 0.78 0.78 0.78 758\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[306 78]\n",
+ " [ 86 288]]\n",
+ "\n",
+ "-AUC 0.7834642379679144\n",
+ "\n",
+ "Validation performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.83 0.80 0.81 54\n",
+ " 1 0.79 0.82 0.81 51\n",
+ "\n",
+ " accuracy 0.81 105\n",
+ " macro avg 0.81 0.81 0.81 105\n",
+ "weighted avg 0.81 0.81 0.81 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[43 11]\n",
+ " [ 9 42]]\n",
+ "\n",
+ "-AUC 0.8099128540305011\n",
+ "\n",
+ "Test performance\n",
+ "=====\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.41 0.91 0.57 35\n",
+ " 1 0.93 0.46 0.61 85\n",
+ "\n",
+ " accuracy 0.59 120\n",
+ " macro avg 0.67 0.69 0.59 120\n",
+ "weighted avg 0.78 0.59 0.60 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[32 3]\n",
+ " [46 39]]\n",
+ "\n",
+ "-AUC 0.6865546218487394\n"
+ ]
+ }
+ ],
+ "source": [
+ "## simple logistic regression\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix\n",
+ "\n",
+ "#model = RandomForestClassifier() \n",
+ "model = LogisticRegression(penalty=\"none\", solver=\"lbfgs\")\n",
+ "\n",
+ "model.fit(X_train, y_train)\n",
+ "if hasattr(model, \"coef_\"):\n",
+ " coeffs = dict(zip(feat_names, model.coef_[0]))\n",
+ " print(coeffs)\n",
+ " print(\"\\nsignificant coeffs\", {k:v for k,v in coeffs.items() if np.abs(v) > .01})\n",
+ "print()\n",
+ "\n",
+ "def evaluate_model(model, X, y):\n",
+ " pred_y = model.predict(X)\n",
+ "\n",
+ " print(\"-Classification metrics\\n\", classification_report(y, pred_y))\n",
+ " print()\n",
+ " print(\"-Confusion matrix\\n\", confusion_matrix(y, pred_y))\n",
+ " print()\n",
+ " print(\"-AUC\", roc_auc_score(y, pred_y))\n",
+ " \n",
+ "\n",
+ "def make_prediction_df(model, X, patient_ids, method_name):\n",
+ " cls_1_index = np.where(model.classes_ == 1)[0][0]\n",
+ " pred_score = model.predict_proba(X)[:, cls_1_index]\n",
+ " return pd.DataFrame({\"PatientID\": patient_ids, f\"prediction_{method_name}\": pred_score})\n",
+ " \n",
+ " \n",
+ "print(\"Training performance\\n=====\")\n",
+ "evaluate_model(model, X_train, y_train)\n",
+ "\n",
+ "print(\"\\nValidation performance\\n=====\")\n",
+ "evaluate_model(model, X_valid, y_valid)\n",
+ "\n",
+ "print(\"\\nTest performance\\n=====\")\n",
+ "evaluate_model(model, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "a7ceb785",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-logistic-regression-nopenalty-lbfgs-trainedOn758PatientsWithImputation\"\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "chemical-organ",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if hasattr(model, \"feature_importances_\"):\n",
+ " imp = dict(zip(feat_names, model.feature_importances_))\n",
+ " for k, v in imp.items():\n",
+ " print(k,v)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "scientific-cancer",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fitting 5 folds for each of 4 candidates, totalling 20 fits\n",
+ ": grid-search best parameters {'fit__class_weight': 'balanced', 'fit__penalty': 'l1'}\n",
+ ": grid-search best score 0.7559515510630882\n",
+ "\n",
+ "\n",
+ "Training performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.76 0.78 0.77 384\n",
+ " 1 0.77 0.75 0.76 374\n",
+ "\n",
+ " accuracy 0.77 758\n",
+ " macro avg 0.77 0.77 0.77 758\n",
+ "weighted avg 0.77 0.77 0.77 758\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[300 84]\n",
+ " [ 93 281]]\n",
+ "\n",
+ "-AUC 0.766293449197861\n",
+ "\n",
+ "Validation performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.83 0.80 0.81 54\n",
+ " 1 0.79 0.82 0.81 51\n",
+ "\n",
+ " accuracy 0.81 105\n",
+ " macro avg 0.81 0.81 0.81 105\n",
+ "weighted avg 0.81 0.81 0.81 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[43 11]\n",
+ " [ 9 42]]\n",
+ "\n",
+ "-AUC 0.8099128540305011\n",
+ "\n",
+ "Test performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.42 0.91 0.57 35\n",
+ " 1 0.93 0.47 0.62 85\n",
+ "\n",
+ " accuracy 0.60 120\n",
+ " macro avg 0.67 0.69 0.60 120\n",
+ "weighted avg 0.78 0.60 0.61 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[32 3]\n",
+ " [45 40]]\n",
+ "\n",
+ "-AUC 0.6924369747899161\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.model_selection import GridSearchCV, KFold\n",
+ "\n",
+ "# Logistic regression; the penalty (l1 or l2) is selected by the grid search below.\n",
+ "logistic_pipe = Pipeline([\n",
+ " #(\"scaling\", MinMaxScaler()),\n",
+ " (\"fit\", LogisticRegression(solver=\"liblinear\"))])\n",
+ "\n",
+ "logistic_grid = {\n",
+ " #\"scaling\": [\"passthrough\", MinMaxScaler(), StandardScaler()],\n",
+ " \"fit__penalty\": [\"l1\", \"l2\"],\n",
+ " \"fit__class_weight\": [None, \"balanced\"]}\n",
+ "\n",
+ "\n",
+ "searcher = GridSearchCV(\n",
+ " estimator=logistic_pipe,\n",
+ " param_grid=logistic_grid,\n",
+ " cv=KFold(n_splits=5, shuffle=False),\n",
+ " #scoring=\"neg_mean_absolute_error\",\n",
+ " verbose=1)\n",
+ "searcher.fit(X_train, y_train)\n",
+ "print(\": grid-search best parameters\", searcher.best_params_)\n",
+ "print(\": grid-search best score\", searcher.best_score_)\n",
+ "print()\n",
+ "print(\"\\nTraining performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_train, y_train)\n",
+ "print(\"\\nValidation performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_valid, y_valid)\n",
+ "print(\"\\nTest performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "a499fb59",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn758PatientsWithImputation\"\n",
+ "model = searcher.best_estimator_\n",
+ "\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "dominican-colonial",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fitting 5 folds for each of 400 candidates, totalling 2000 fits\n",
+ ": grid-search best parameters {'fit__max_depth': 20, 'fit__max_leaf_nodes': 100, 'fit__min_impurity_decrease': 0, 'fit__n_estimators': 100}\n",
+ ": grid-search best score 0.7533809689787383\n",
+ "\n",
+ "\n",
+ "Training performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.99 0.99 0.99 384\n",
+ " 1 0.99 0.99 0.99 374\n",
+ "\n",
+ " accuracy 0.99 758\n",
+ " macro avg 0.99 0.99 0.99 758\n",
+ "weighted avg 0.99 0.99 0.99 758\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[379 5]\n",
+ " [ 3 371]]\n",
+ "\n",
+ "-AUC 0.9894788881461675\n",
+ "\n",
+ "Validation performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.79 0.78 0.79 54\n",
+ " 1 0.77 0.78 0.78 51\n",
+ "\n",
+ " accuracy 0.78 105\n",
+ " macro avg 0.78 0.78 0.78 105\n",
+ "weighted avg 0.78 0.78 0.78 105\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[42 12]\n",
+ " [11 40]]\n",
+ "\n",
+ "-AUC 0.7810457516339868\n",
+ "\n",
+ "Test performance of best model\n",
+ "\n",
+ "-Classification metrics\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.44 0.91 0.59 35\n",
+ " 1 0.94 0.52 0.67 85\n",
+ "\n",
+ " accuracy 0.63 120\n",
+ " macro avg 0.69 0.72 0.63 120\n",
+ "weighted avg 0.79 0.63 0.65 120\n",
+ "\n",
+ "\n",
+ "-Confusion matrix\n",
+ " [[32 3]\n",
+ " [41 44]]\n",
+ "\n",
+ "-AUC 0.7159663865546217\n"
+ ]
+ }
+ ],
+ "source": [
+ "rf_pipe = Pipeline([\n",
+ " #(\"scaling\", MinMaxScaler()),\n",
+ " (\"fit\", RandomForestClassifier(n_estimators=100,\n",
+ " # max_depth=5,\n",
+ " # max_leaf_nodes=20,\n",
+ " # min_impurity_decrease=.1\n",
+ " random_state=42,\n",
+ " ))\n",
+ "])\n",
+ "rf_grid = {\n",
+ " #\"scaling\": [MinMaxScaler(), StandardScaler()],\n",
+ " \"fit__n_estimators\": [5, 10, 20, 100, 200],\n",
+ " \"fit__max_depth\": [None, 5, 10, 20, 50],\n",
+ " \"fit__max_leaf_nodes\": [None, 20, 50, 100],\n",
+ " \"fit__min_impurity_decrease\": [0, 0.1, .5, 1] }\n",
+ "\n",
+ "searcher = GridSearchCV(\n",
+ " estimator=rf_pipe,\n",
+ " param_grid=rf_grid,\n",
+ " cv=KFold(n_splits=5, shuffle=False),\n",
+ " #scoring=\"neg_mean_absolute_error\",\n",
+ " verbose=1)\n",
+ "searcher.fit(X_train, y_train)\n",
+ "\n",
+ "print(\": grid-search best parameters\", searcher.best_params_)\n",
+ "print(\": grid-search best score\", searcher.best_score_)\n",
+ "print()\n",
+ "print(\"\\nTraining performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_train, y_train)\n",
+ "print(\"\\nValidation performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_valid, y_valid)\n",
+ "print(\"\\nTest performance of best model\\n\")\n",
+ "evaluate_model(searcher.best_estimator_, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "taken-mortgage",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = \"Clinical-randomforest-100estimators-trainedOn758PatientsWithImputation\"\n",
+ "model = searcher.best_estimator_\n",
+ "\n",
+ "pred_train = make_prediction_df(model, X_train, patient_ids=ids_train, \n",
+ " method_name=model_name)\n",
+ "\n",
+ "pred_valid = make_prediction_df(model, X_valid, patient_ids=ids_valid, \n",
+ " method_name=model_name)\n",
+ "pred_test = make_prediction_df(model, X_test, patient_ids=ids_test, \n",
+ " method_name=model_name)\n",
+ "pred_test.to_csv(f\"/home/starke88/git/covid_data_challenge/clinical_model/predictions_{model_name}.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2163966",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients.csv b/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients.csv
new file mode 100644
index 0000000..aeedc33
--- /dev/null
+++ b/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn312Patients
+P_102,0.33520004647784835
+P_117,0.06827689247042472
+P_16,0.13584828184647155
+P_118,0.7975236836311405
+P_114,0.4282911929313953
+P_2,0.5607635032130464
+P_20,0.8791060732905628
+P_26,0.37615970263923376
+P_38,0.17539619031427528
+P_41,0.5662134515078459
+P_49,0.1607696261695052
+P_45,0.39998307827296087
+P_40,0.6567350933745325
+P_55,0.15078852023168587
+P_58,0.6936727399468227
+P_53,0.07938684251078379
+P_54,0.18994299016175986
+P_65,0.2471092057157912
+P_68,0.21118541518519335
+P_7,0.07290098155043286
+P_75,0.2270635764561114
+P_112,0.5377720346883706
+P_116,0.19923650965029246
+P_14,0.5841029589439515
+P_17,0.764786127124441
+P_111,0.44908579731238213
+P_1,0.17901755433065009
+P_106,0.7724132167599
+P_113,0.45708141399737634
+P_104,0.3294183049598339
+P_21,0.41069691477025
+P_25,0.3512817849966773
+P_44,0.08860785830460742
+P_46,0.27606226883431606
+P_43,0.20652410772080362
+P_30,0.6120256368189827
+P_50,0.41241659393951874
+P_60,0.3891833520344404
+P_6,0.4580981480214153
+P_69,0.18005363844199188
+P_66,0.11480638012803182
+P_85,0.8581273726962804
+P_94,0.2278812627780469
+P_84,0.18295590526231506
+P_83,0.12304522933578058
+P_91,0.24543700816344285
+P_76,0.2198964059213189
+P_77,0.530927099559969
+P_81,0.4658834530602529
+P_82,0.8765728099098288
+P_93,0.2307890949678067
+P_98,0.2817343412587218
+P_96,0.3128442272142378
+P_95,0.420416680576292
+P_80,0.875917955810806
+P_87,0.8473339962284742
+P_11,0.12939075660289623
+P_115,0.17932515301714272
+P_18,0.3412197166013632
+P_12,0.4085514105511912
+P_119,0.31999979532889444
+P_15,0.3317820046672228
+P_100,0.442321181987111
+P_107,0.616294812707784
+P_10,0.9886770584328072
+P_101,0.7923609634624441
+P_108,0.08003254948363757
+P_110,0.8563301321803212
+P_24,0.665118287186855
+P_28,0.8138245654764317
+P_23,0.8063273894728695
+P_29,0.3245450002343199
+P_32,0.606421411801071
+P_37,0.6802529357544478
+P_34,0.7860112241050077
+P_35,0.626668308143247
+P_36,0.34647149057349647
+P_39,0.09921526941333601
+P_4,0.7419867041541264
+P_42,0.376626436080426
+P_63,0.5375326047740958
+P_61,0.18159398036610488
+P_62,0.4121356368846893
+P_59,0.4139631148649756
+P_70,0.32800228129546755
+P_72,0.363302970875322
+P_74,0.22693675422272985
+P_97,0.4503424328184856
+P_99,0.21009202275843925
+P_73,0.19939571075463042
+P_78,0.18419547543075346
+P_71,0.3254840991389562
+P_120,0.3456758805905774
+P_13,0.3002468351439
+P_22,0.5961214844455944
+P_105,0.7083471027674944
+P_109,0.6882021260968174
+P_103,0.4777795011847989
+P_19,0.25850371254408455
+P_27,0.26326323331807183
+P_3,0.34560859025267243
+P_33,0.12648624250858967
+P_31,0.7299946886551383
+P_47,0.46315094593381667
+P_48,0.20489716776284475
+P_57,0.16939842591837587
+P_5,0.2860990510662038
+P_51,0.8770743465243397
+P_56,0.19027433629860382
+P_52,0.08535363595664093
+P_64,0.4535418761626194
+P_67,0.37293363879865526
+P_79,0.3726676202504323
+P_8,0.3187365182964029
+P_89,0.3424301683759125
+P_88,0.20473068036960596
+P_92,0.05807647412335641
+P_86,0.4948297380918963
+P_9,0.6778152753642893
+P_90,0.32128617511339685
diff --git a/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn758PatientsWithImputation.csv b/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn758PatientsWithImputation.csv
new file mode 100644
index 0000000..359f7be
--- /dev/null
+++ b/clinical_model/predictions_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn758PatientsWithImputation.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-logistic-regression-gridsearchedParams-l1penalty-liblinear-trainedOn758PatientsWithImputation
+P_102,0.3859644918622136
+P_117,0.07284203230498752
+P_16,0.07596895381275044
+P_118,0.9274230083127741
+P_114,0.5142052481657664
+P_2,0.5203429898535421
+P_20,0.8000266412980864
+P_26,0.22480692845410072
+P_38,0.1822518446772093
+P_41,0.5716081459889104
+P_49,0.2473877082260779
+P_45,0.16565900991729404
+P_40,0.7093868803545182
+P_55,0.1979382554661376
+P_58,0.9370410730987137
+P_53,0.08148851451642634
+P_54,0.08294504685675824
+P_65,0.37617956013175763
+P_68,0.35082219144166366
+P_7,0.12534370993786753
+P_75,0.19789852978823735
+P_112,0.6707704472783014
+P_116,0.19805373425310113
+P_14,0.6263954320158809
+P_17,0.5513416542968695
+P_111,0.428343994858313
+P_1,0.13438014802808462
+P_106,0.7783368402394052
+P_113,0.31689459434644496
+P_104,0.24976073218044986
+P_21,0.23995234648296687
+P_25,0.22963990332329398
+P_44,0.09021678959182249
+P_46,0.3781880682643061
+P_43,0.3265208524727368
+P_30,0.6997096980573042
+P_50,0.4426652399472462
+P_60,0.3207795092231149
+P_6,0.1934011763186107
+P_69,0.13343361173621096
+P_66,0.0740632399118832
+P_85,0.8854477880159981
+P_94,0.23190371268296364
+P_84,0.1955168690975254
+P_83,0.23803926535577632
+P_91,0.23848790972929118
+P_76,0.2843177940396236
+P_77,0.8673916935008623
+P_81,0.4004369413550231
+P_82,0.8742958150342303
+P_93,0.08080227666418292
+P_98,0.467161058965024
+P_96,0.23565533294391522
+P_95,0.5445814366451247
+P_80,0.8282874671584735
+P_87,0.9555111563767205
+P_11,0.059184061740297486
+P_115,0.3225233207342592
+P_18,0.2550110459899284
+P_12,0.623587242858893
+P_119,0.49512849763874905
+P_15,0.5608572003946439
+P_100,0.21366138277655405
+P_107,0.7598771587425981
+P_10,0.9919207015221694
+P_101,0.7063801922502047
+P_108,0.18363395680077926
+P_110,0.9421521314646447
+P_24,0.7365033803070948
+P_28,0.7818310300697472
+P_23,0.6808312092485941
+P_29,0.6400896202911683
+P_32,0.44742023295516
+P_37,0.6946437690664641
+P_34,0.7681504496611707
+P_35,0.600524125260131
+P_36,0.3665075968331374
+P_39,0.17383317531486894
+P_4,0.8587062403466065
+P_42,0.3232414000393499
+P_63,0.7475300485626806
+P_61,0.1767042091824176
+P_62,0.46973162577547806
+P_59,0.43752823467747265
+P_70,0.19947511560093817
+P_72,0.3100108707372959
+P_74,0.1443317377442051
+P_97,0.295647292893258
+P_99,0.2170817253904323
+P_73,0.20278616364162716
+P_78,0.18479805444077033
+P_71,0.581513813023997
+P_120,0.34810997392105003
+P_13,0.5975978103825967
+P_22,0.5369159972769143
+P_105,0.63716075585888
+P_109,0.6121339462956245
+P_103,0.30342540538527224
+P_19,0.10225471287027968
+P_27,0.1726493557480162
+P_3,0.2404355338123172
+P_33,0.31041182111727883
+P_31,0.835361706924908
+P_47,0.3810059842730029
+P_48,0.27986321930047503
+P_57,0.46219508350777755
+P_5,0.08958116840252862
+P_51,0.9275899520989356
+P_56,0.13710198371029936
+P_52,0.07490890194571004
+P_64,0.569648866590487
+P_67,0.4214220485937342
+P_79,0.2721423316936425
+P_8,0.3092079108946637
+P_89,0.26971439193936825
+P_88,0.20075491283355626
+P_92,0.07237953993542133
+P_86,0.5736632700761363
+P_9,0.6491348300524976
+P_90,0.28817196014750335
diff --git a/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn312Patients.csv b/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn312Patients.csv
new file mode 100644
index 0000000..386810a
--- /dev/null
+++ b/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn312Patients.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn312Patients
+P_102,0.24903656599446405
+P_117,0.04003292783510478
+P_16,0.12446757810452885
+P_118,0.9218160374499078
+P_114,0.4286919261325882
+P_2,0.6640975112163694
+P_20,0.9701462449738542
+P_26,0.40350849858638743
+P_38,0.1208567522913325
+P_41,0.7725482931492166
+P_49,0.07875418683872765
+P_45,0.4343138006325572
+P_40,0.7276851679444762
+P_55,0.15253447403064463
+P_58,0.8784259011286092
+P_53,0.025821151017420253
+P_54,0.23949971971236725
+P_65,0.12741975197440444
+P_68,0.20367678616280735
+P_7,0.02510497766076813
+P_75,0.10381775699916347
+P_112,0.5862375095444788
+P_116,0.1541375635924125
+P_14,0.7447034500619543
+P_17,0.8043590491041641
+P_111,0.3928752589840098
+P_1,0.13477245574194524
+P_106,0.8036419967131622
+P_113,0.2999517103060956
+P_104,0.3183387823212139
+P_21,0.5734787382187158
+P_25,0.24574649024284842
+P_44,0.0370316065900205
+P_46,0.22916173088150432
+P_43,0.12339156681362115
+P_30,0.6928260601737747
+P_50,0.32724604748247815
+P_60,0.48186534037452167
+P_6,0.5095930829252183
+P_69,0.1421954988590595
+P_66,0.08087919589605003
+P_85,0.9022399978238933
+P_94,0.10188391243468453
+P_84,0.24780992174710426
+P_83,0.07991659330774005
+P_91,0.2181067932193218
+P_76,0.12944253336885472
+P_77,0.41860703749776523
+P_81,0.4749830682794505
+P_82,0.9321785218637345
+P_93,0.23557437027348954
+P_98,0.369719152762019
+P_96,0.25844036130117654
+P_95,0.28040712453966216
+P_80,0.9771688451492363
+P_87,0.8363211321161775
+P_11,0.12495360676520206
+P_115,0.10173593582734358
+P_18,0.22990816675523057
+P_12,0.38342801146690786
+P_119,0.3162849081476485
+P_15,0.2377489248825254
+P_100,0.5675477509654552
+P_107,0.7246437743095745
+P_10,0.9994069351568323
+P_101,0.8971340337656126
+P_108,0.04077080591014789
+P_110,0.8751943820848411
+P_24,0.8781596323416284
+P_28,0.9644450441403449
+P_23,0.9309188166665078
+P_29,0.3510101876751694
+P_32,0.7051648055228926
+P_37,0.6414231072570596
+P_34,0.882659254222237
+P_35,0.7156817381482089
+P_36,0.28832412541491176
+P_39,0.030598088898500145
+P_4,0.8263203337126278
+P_42,0.297984514254678
+P_63,0.6694741987477717
+P_61,0.09583497157167596
+P_62,0.29356760542853205
+P_59,0.3500394829880046
+P_70,0.2890625180939771
+P_72,0.4674465959129291
+P_74,0.15504646053445142
+P_97,0.43165048417498425
+P_99,0.11431370256706168
+P_73,0.11272392805569711
+P_78,0.11492311015642452
+P_71,0.21199110034589025
+P_120,0.27036557823277135
+P_13,0.21381100516035592
+P_22,0.5516303971242341
+P_105,0.8390841609161274
+P_109,0.7820717005086841
+P_103,0.41686027760095956
+P_19,0.2729547667960578
+P_27,0.2537375276543555
+P_3,0.2738278688121344
+P_33,0.07477187752443443
+P_31,0.8448419680451451
+P_47,0.43146175916292073
+P_48,0.15466418064258186
+P_57,0.11060379553685333
+P_5,0.3522418049485151
+P_51,0.9279976991287591
+P_56,0.13622296469253095
+P_52,0.054660053928964995
+P_64,0.7182330795665481
+P_67,0.2767831256918614
+P_79,0.5145329801340882
+P_8,0.2647806566559488
+P_89,0.3273118637511314
+P_88,0.09156193298298344
+P_92,0.02644416054907079
+P_86,0.5000568945321757
+P_9,0.8194760027662711
+P_90,0.2293672292636861
diff --git a/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn758PatientsWithImputation.csv b/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn758PatientsWithImputation.csv
new file mode 100644
index 0000000..6572a5d
--- /dev/null
+++ b/clinical_model/predictions_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn758PatientsWithImputation.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-logistic-regression-nopenalty-lbfgs-trainedOn758PatientsWithImputation
+P_102,0.31835132661122767
+P_117,0.031654828816397126
+P_16,0.03742451771059102
+P_118,0.9263724726367898
+P_114,0.4487215180364961
+P_2,0.41209183762170254
+P_20,0.9413581709090953
+P_26,0.12254455868162095
+P_38,0.13686780719998914
+P_41,0.5546815828055516
+P_49,0.1853769046994918
+P_45,0.10769026257611337
+P_40,0.6935879208992427
+P_55,0.12311850411956264
+P_58,0.901061569837301
+P_53,0.05130167506402003
+P_54,0.038654491665258334
+P_65,0.2787366639095199
+P_68,0.2835908781423644
+P_7,0.06928335619636614
+P_75,0.18594873175266755
+P_112,0.6715641643206077
+P_116,0.16121286586660563
+P_14,0.6837813965299712
+P_17,0.6570071771667383
+P_111,0.47371976332323207
+P_1,0.07678311041218151
+P_106,0.8314763317530923
+P_113,0.3128436911901391
+P_104,0.2676081896898872
+P_21,0.1782253660252437
+P_25,0.20243887091158302
+P_44,0.059795204792418614
+P_46,0.30253807654342935
+P_43,0.3170826315702827
+P_30,0.7659238004605401
+P_50,0.453595275989049
+P_60,0.18855187541682203
+P_6,0.13359181337363316
+P_69,0.07185082039443393
+P_66,0.036344064257934225
+P_85,0.9014752010278936
+P_94,0.19104165975823614
+P_84,0.13230119920477593
+P_83,0.16897721157353496
+P_91,0.22527919584513
+P_76,0.21418197848603898
+P_77,0.8470080841044276
+P_81,0.40078826191809735
+P_82,0.9384661016546725
+P_93,0.04338501174989692
+P_98,0.4292495039269186
+P_96,0.18091404535537223
+P_95,0.51773101630826
+P_80,0.8182041515561647
+P_87,0.9735490678089487
+P_11,0.02846983665068818
+P_115,0.30101989295006626
+P_18,0.20607772228612625
+P_12,0.6211695827540649
+P_119,0.561789957082007
+P_15,0.5594166920975943
+P_100,0.15215523081216822
+P_107,0.7683849853685651
+P_10,0.9990551129397476
+P_101,0.6336624377345574
+P_108,0.10445505154862177
+P_110,0.9676982887548282
+P_24,0.6523241473862897
+P_28,0.8455845387060906
+P_23,0.6945626539230936
+P_29,0.6508184140286792
+P_32,0.5114257954822891
+P_37,0.7066189076650161
+P_34,0.8165771053854749
+P_35,0.6097645210880083
+P_36,0.30071317136038617
+P_39,0.1162753785938526
+P_4,0.8831712394278431
+P_42,0.3237176291753919
+P_63,0.7803444393399233
+P_61,0.13177536731575853
+P_62,0.416799561036709
+P_59,0.38225233800286906
+P_70,0.12265547315914713
+P_72,0.16934677202379259
+P_74,0.12062993425169195
+P_97,0.28002036014426435
+P_99,0.1679643155751617
+P_73,0.1611190550566634
+P_78,0.1664093982438191
+P_71,0.6085060575977724
+P_120,0.48033453626212197
+P_13,0.5701560862816882
+P_22,0.5621676251919606
+P_105,0.7336890101568924
+P_109,0.6593722999411834
+P_103,0.28491705072494317
+P_19,0.057279600804310396
+P_27,0.10588685917484085
+P_3,0.20766057194703635
+P_33,0.24522873540184817
+P_31,0.8179099011904133
+P_47,0.4913721860539052
+P_48,0.2331864252558713
+P_57,0.39122463485552156
+P_5,0.05547535275840547
+P_51,0.9426698156648277
+P_56,0.06723012997180255
+P_52,0.03595983576613419
+P_64,0.4828763552319358
+P_67,0.3941998337156697
+P_79,0.23583730755162813
+P_8,0.2713379872569081
+P_89,0.19709344495463335
+P_88,0.1265176302770373
+P_92,0.029902999142444903
+P_86,0.6033392456414456
+P_9,0.5807802704104573
+P_90,0.23251775468693256
diff --git a/clinical_model/predictions_Clinical-randomforest-100estimators-trainedOn758PatientsWithImputation.csv b/clinical_model/predictions_Clinical-randomforest-100estimators-trainedOn758PatientsWithImputation.csv
new file mode 100644
index 0000000..a1e19ab
--- /dev/null
+++ b/clinical_model/predictions_Clinical-randomforest-100estimators-trainedOn758PatientsWithImputation.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-randomforest-100estimators-trainedOn758PatientsWithImputation
+P_102,0.38946550336041985
+P_117,0.21282102831486885
+P_16,0.2292309049178195
+P_118,0.8573664001365144
+P_114,0.457689330355186
+P_2,0.8420358002655038
+P_20,0.7528432827774967
+P_26,0.24425810386334748
+P_38,0.29173251483568535
+P_41,0.656733411843052
+P_49,0.08799841020508709
+P_45,0.06698014714367952
+P_40,0.42709772892278974
+P_55,0.44534673653583484
+P_58,0.5900473936649999
+P_53,0.1530290965640159
+P_54,0.22204535598555325
+P_65,0.571362045176132
+P_68,0.29297598199300995
+P_7,0.2994842037318093
+P_75,0.14094637035751595
+P_112,0.5946715797653012
+P_116,0.07962602409320528
+P_14,0.7026038037522898
+P_17,0.6877395293713917
+P_111,0.384172711100478
+P_1,0.14695364889845955
+P_106,0.6628320711542883
+P_113,0.2468814460870966
+P_104,0.5301858443644191
+P_21,0.24849563982118067
+P_25,0.262311664929914
+P_44,0.027075693448032782
+P_46,0.31094975734390945
+P_43,0.6320940826036633
+P_30,0.8124279049116454
+P_50,0.2837401202915044
+P_60,0.6210785165034133
+P_6,0.1778964334173122
+P_69,0.02551757876760926
+P_66,0.01922233378480941
+P_85,0.7003444699466498
+P_94,0.10053982601769644
+P_84,0.4213463524866704
+P_83,0.4726997704752738
+P_91,0.09917730976213647
+P_76,0.20653495968118804
+P_77,0.582829309993914
+P_81,0.4174469791476917
+P_82,0.7558552935853973
+P_93,0.1558813883923038
+P_98,0.6312581558880244
+P_96,0.2833861091846772
+P_95,0.5094860156965737
+P_80,0.8803020997111964
+P_87,0.798605070433317
+P_11,0.2375383247517553
+P_115,0.4060607417634127
+P_18,0.15221828087832698
+P_12,0.6895586841901787
+P_119,0.7532326563388537
+P_15,0.6162961454854518
+P_100,0.2656280351221656
+P_107,0.6075562025213671
+P_10,0.819287313528472
+P_101,0.5418641730780982
+P_108,0.2065326710717261
+P_110,0.7644121052361565
+P_24,0.7197081996709395
+P_28,0.8517132591972909
+P_23,0.5829919166506827
+P_29,0.852096502982105
+P_32,0.4840101160178048
+P_37,0.5665506224238773
+P_34,0.5382771261139793
+P_35,0.4385345621302919
+P_36,0.05823659162874619
+P_39,0.02019208733288843
+P_4,0.6970987996312197
+P_42,0.2808675912614251
+P_63,0.7566663986230973
+P_61,0.2562584769528749
+P_62,0.5566724226360306
+P_59,0.3504794703875266
+P_70,0.15315224977807163
+P_72,0.20977103991222573
+P_74,0.135401898322289
+P_97,0.05458064834727395
+P_99,0.1466862165111942
+P_73,0.06753047604515297
+P_78,0.3309144651899263
+P_71,0.5088787113144024
+P_120,0.4531747487154755
+P_13,0.6939067244800701
+P_22,0.1438258555795674
+P_105,0.783879444593995
+P_109,0.5463574926544849
+P_103,0.08031589638700326
+P_19,0.1050073216613282
+P_27,0.3458076676894414
+P_3,0.1028003349570211
+P_33,0.4926338529597794
+P_31,0.650329965958054
+P_47,0.5183042333533406
+P_48,0.15527066117798724
+P_57,0.4430172205023816
+P_5,0.05170384323132989
+P_51,0.8645575025560717
+P_56,0.3186072782197584
+P_52,0.203838418651035
+P_64,0.6492725510950499
+P_67,0.07489726418238614
+P_79,0.20621292474964892
+P_8,0.19378498619689222
+P_89,0.34821110062882354
+P_88,0.19459208091786806
+P_92,0.33073754363143565
+P_86,0.6885462291774193
+P_9,0.8487928995809984
+P_90,0.21821745917040292
diff --git a/clinical_model/predictions_Clinical-randomforest-10estimators-trainedOn312Patients.csv b/clinical_model/predictions_Clinical-randomforest-10estimators-trainedOn312Patients.csv
new file mode 100644
index 0000000..216bd16
--- /dev/null
+++ b/clinical_model/predictions_Clinical-randomforest-10estimators-trainedOn312Patients.csv
@@ -0,0 +1,121 @@
+PatientID,prediction_Clinical-randomforest-10estimators-trainedOn312Patients
+P_102,0.6425278613155709
+P_117,0.23339452285392942
+P_16,0.382990132037891
+P_118,0.9282818035426731
+P_114,0.3877030197450563
+P_2,0.7176854742830476
+P_20,0.5874486027316216
+P_26,0.24590858248675715
+P_38,0.19759922660099263
+P_41,0.38828407284289634
+P_49,0.05482132198137636
+P_45,0.12953760353060625
+P_40,0.6554076846741177
+P_55,0.25681802424271527
+P_58,0.6599216988901248
+P_53,0.2728883144616123
+P_54,0.07001806553632886
+P_65,0.8024287344622024
+P_68,0.2804062441448522
+P_7,0.1863494735731443
+P_75,0.24436174441330313
+P_112,0.45450655406537754
+P_116,0.12089480074142227
+P_14,0.731673665061823
+P_17,0.757881631006631
+P_111,0.3625352904213991
+P_1,0.032433262279883815
+P_106,0.600067587054933
+P_113,0.25120229429085184
+P_104,0.5463005309217851
+P_21,0.33596179970040774
+P_25,0.4510787063430278
+P_44,0.032433262279883815
+P_46,0.15610908117428002
+P_43,0.4316392252026581
+P_30,0.5744508172139752
+P_50,0.2103876881394542
+P_60,0.6400735339845401
+P_6,0.11473607048947523
+P_69,0.032433262279883815
+P_66,0.032433262279883815
+P_85,0.49854260257562133
+P_94,0.20359240249774424
+P_84,0.6414821048847262
+P_83,0.2095265121315124
+P_91,0.2746567804301041
+P_76,0.05482132198137636
+P_77,0.8015773594690364
+P_81,0.42645251042634535
+P_82,0.7441847826086957
+P_93,0.06975935198789951
+P_98,0.7455578961017846
+P_96,0.26838604212465017
+P_95,0.2927685582052161
+P_80,0.7342954826321864
+P_87,0.9049493350852046
+P_11,0.14057321786538174
+P_115,0.14558276538480636
+P_18,0.08257265122063609
+P_12,0.9075995288836542
+P_119,0.6860976506630628
+P_15,0.6853565583352818
+P_100,0.42991953856728654
+P_107,0.7551234131021365
+P_10,0.7669739653621233
+P_101,0.8637369482792824
+P_108,0.10063518785885858
+P_110,0.8467355353331086
+P_24,0.7273715316660445
+P_28,0.9045257276507277
+P_23,0.5647038950715422
+P_29,0.5578085982873215
+P_32,0.47887474182724776
+P_37,0.6939094689094689
+P_34,0.657979461695531
+P_35,0.4935015308235123
+P_36,0.05482132198137636
+P_39,0.032433262279883815
+P_4,0.4665488814131521
+P_42,0.33675962351864125
+P_63,0.5857653713115479
+P_61,0.06534763777085004
+P_62,0.6061990228846782
+P_59,0.5056977911592361
+P_70,0.07228525227319468
+P_72,0.340420308960112
+P_74,0.05482132198137636
+P_97,0.05642300842432544
+P_99,0.12089480074142227
+P_73,0.12302324756035113
+P_78,0.5295393413775765
+P_71,0.4766596480201131
+P_120,0.6425225983739916
+P_13,0.6093354146721033
+P_22,0.3162071709183517
+P_105,0.6609344558913878
+P_109,0.1607578369964451
+P_103,0.29633565852841953
+P_19,0.32486360470515396
+P_27,0.155782582555422
+P_3,0.23361362904373592
+P_33,0.47015387670458153
+P_31,0.6582500280892166
+P_47,0.38842939314057395
+P_48,0.14912003634370707
+P_57,0.15576996142033422
+P_5,0.07881106812581798
+P_51,0.9388180967968202
+P_56,0.04989719257170215
+P_52,0.10063518785885858
+P_64,0.6925484526851948
+P_67,0.2087375332746368
+P_79,0.3102554658777222
+P_8,0.2283313033965022
+P_89,0.12390655119871508
+P_88,0.3149504891021572
+P_92,0.05176431096055968
+P_86,0.7091620551640344
+P_9,0.8156592620765963
+P_90,0.4152355895738249
diff --git a/clinical_model/valid_ids.csv b/clinical_model/valid_ids.csv
new file mode 100644
index 0000000..b4e05b1
--- /dev/null
+++ b/clinical_model/valid_ids.csv
@@ -0,0 +1,105 @@
+P_384
+P_417
+P_551
+P_479
+P_555
+P_376
+P_1_26
+P_379
+P_441
+P_450
+P_760
+P_756
+P_435
+P_442
+P_766
+P_695
+P_537
+P_687
+P_467
+P_1_126
+P_1_75
+P_653
+P_1_149
+P_411
+P_358
+P_740
+P_1_81
+P_634
+P_616
+P_382
+P_342
+P_519
+P_476
+P_639
+P_759
+P_350
+P_484
+P_635
+P_507
+P_439
+P_798
+P_598
+P_414
+P_583
+P_590
+P_485
+P_511
+P_426
+P_304
+P_842
+P_797
+P_837
+P_314
+P_572
+P_641
+P_803
+P_1_55
+P_702
+P_691
+P_518
+P_1_95
+P_521
+P_589
+P_840
+P_1_131
+P_772
+P_368
+P_292
+P_750
+P_274
+P_1_146
+P_650
+P_505
+P_733
+P_623
+P_697
+P_620
+P_1_140
+P_609
+P_808
+P_526
+P_383
+P_595
+P_587
+P_580
+P_502
+P_1_22
+P_324
+P_1_72
+P_725
+P_336
+P_491
+P_781
+P_407
+P_546
+P_645
+P_425
+P_735
+P_445
+P_812
+P_385
+P_1_66
+P_743
+P_560
+P_633