Skip to content

Commit 2ad5b87

Browse files
committed
Add bland-altman plots to evaluation
1 parent 16089e9 commit 2ad5b87

File tree

2 files changed

+141
-7
lines changed

2 files changed

+141
-7
lines changed

EvaluateModel.ipynb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,24 @@
324324
"generate_confusion_matrices(results_df, group_by=\"Sex\", save_path=\"outputs/actinet_vs_bbaa/by_sex.pdf\")\n",
325325
"generate_confusion_matrices(results_df, group_by=\"Age Band\", save_path=\"outputs/actinet_vs_bbaa/by_age.pdf\")"
326326
]
327+
},
328+
{
329+
"cell_type": "markdown",
330+
"metadata": {},
331+
"source": [
332+
"Bland-Altman plots"
333+
]
334+
},
335+
{
336+
"cell_type": "code",
337+
"execution_count": null,
338+
"metadata": {},
339+
"outputs": [],
340+
"source": [
341+
"generate_bland_altman_plots(results_df, save_path=\"outputs/actinet_vs_bbaa/bland_altman/full_population.pdf\")\n",
342+
"generate_bland_altman_plots(results_df, group_by=\"Sex\", save_path=\"outputs/actinet_vs_bbaa/bland_altman\")\n",
343+
"generate_bland_altman_plots(results_df, group_by=\"Age Band\", save_path=\"outputs/actinet_vs_bbaa/bland_altman\")"
344+
]
327345
}
328346
],
329347
"metadata": {

src/actinet/utils/eval_utils.py

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import seaborn as sns
33
import numpy as np
44
import pandas as pd
5+
import os
56
from scipy import stats
67
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, balanced_accuracy_score, confusion_matrix
78
import warnings
@@ -57,14 +58,17 @@ def build_confusion_matrix_data(results: pd.DataFrame, age_band=None, sex=None):
5758
y_true = np.hstack(model_results.loc[model_results["Model"]=="actinet", 'True'])
5859
y_pred_bbaa = np.hstack(model_results.loc[model_results["Model"]=="accelerometer", 'Predicted'])
5960
y_pred_actinet = np.hstack(model_results.loc[model_results["Model"]=="actinet", 'Predicted'])
61+
62+
population = len(model_results['Participant'].unique())
6063

61-
return y_true, y_pred_bbaa, y_pred_actinet
64+
return y_true, y_pred_bbaa, y_pred_actinet, population
6265

6366

6467
def plot_and_save_fig(fig, save_path=None):
    """Show pending matplotlib output and optionally save ``fig`` as a PDF.

    Args:
        fig: Figure object exposing ``savefig``.
        save_path: Destination file path. When falsy, nothing is written.
            Any missing parent directory is created first.
    """
    plt.show()
    if save_path:
        # A bare filename has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create the directory when there is one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        fig.savefig(save_path, format='pdf', dpi=800, bbox_inches='tight')
6973

7074

@@ -75,7 +79,7 @@ def generate_confusion_matrices(results_df, group_by=None, save_path=None):
7579
fig.suptitle("Confusion matrices for full Capture-24 population using 5-fold group cross-validation",
7680
fontsize=16)
7781

78-
y_true, y_pred_bbaa, y_pred_actinet = build_confusion_matrix_data(results_df)
82+
y_true, y_pred_bbaa, y_pred_actinet, _ = build_confusion_matrix_data(results_df)
7983

8084
plot_confusion_matrix(y_true, y_pred_bbaa, 'accelerometer', ax=axs[0], fontsize=14)
8185
plot_confusion_matrix(y_true, y_pred_actinet, 'actinet', ax=axs[1], fontsize=14)
@@ -90,13 +94,13 @@ def generate_confusion_matrices(results_df, group_by=None, save_path=None):
9094

9195
subfigs = fig.subfigures(nrows=len(unique_groups), ncols=1)
9296

93-
for i, (group, subfig) in enumerate(zip(unique_groups, subfigs)):
94-
subfig.suptitle(f"{group_by}: {group}", fontsize=14)
95-
axs = subfig.subplots(nrows=1, ncols=2, sharey=True)
96-
97-
y_true, y_pred_bbaa, y_pred_actinet =\
97+
for group, subfig in zip(unique_groups, subfigs):
98+
y_true, y_pred_bbaa, y_pred_actinet, population =\
9899
build_confusion_matrix_data(results_df, **{group_by.replace(' ', '_').lower(): group})
99100

101+
subfig.suptitle(f"{group_by}: {group} (n = {population})", fontsize=14)
102+
axs = subfig.subplots(nrows=1, ncols=2, sharey=True)
103+
100104
plot_confusion_matrix(y_true, y_pred_bbaa, 'accelerometer', ax=axs[0], fontsize=14)
101105
plot_confusion_matrix(y_true, y_pred_actinet, 'actinet', ax=axs[1], fontsize=14)
102106

@@ -169,3 +173,115 @@ def plot_boxplots(df, x, y='Macro F1', hue='Model'):
169173
ax.set_ylabel("Macro F1 Score")
170174
plt.title(f"Macro F1 by {x}")
171175
plt.show()
176+
177+
178+
def build_bland_altman_data(results: pd.DataFrame, activity, age_band=None, sex=None):
    """Collect the predicted amount of one activity for both models.

    Rows are optionally restricted to a single ``'Age Band'`` and/or ``'Sex'``
    value. For each remaining row of a model, the ``activity`` entry of its
    ``'Pred_dict'`` value is looked up with ``.get`` (0 when the key is absent).

    Args:
        results: Frame with ``'Model'``, ``'Pred_dict'``, ``'Participant'``,
            ``'Age Band'`` and ``'Sex'`` columns.
        activity: Key to read from each ``'Pred_dict'`` entry.
        age_band: Keep only rows with this ``'Age Band'`` value, if given.
        sex: Keep only rows with this ``'Sex'`` value, if given.

    Returns:
        Tuple ``(accelerometer_values, actinet_values, population)`` where the
        first two are lists in row order and ``population`` is the number of
        distinct ``'Participant'`` values among the filtered rows.
    """
    subset = results.copy()
    if age_band is not None:
        subset = subset[subset['Age Band'] == age_band]
    if sex is not None:
        subset = subset[subset["Sex"] == sex]

    # Moderate-vigorous amounts are converted from hours to minutes.
    to_minutes = activity == 'moderate-vigorous'

    def _amounts_for(model_name):
        pred_dicts = subset.loc[subset["Model"] == model_name, "Pred_dict"]
        amounts = [entry.get(activity, 0) for entry in pred_dicts]
        if to_minutes:
            amounts = [60*amount for amount in amounts]
        return amounts

    activity_bbaa_pred = _amounts_for("accelerometer")
    activity_actinet_pred = _amounts_for("actinet")
    population = len(subset['Participant'].unique())

    return activity_bbaa_pred, activity_actinet_pred, population
196+
197+
198+
def bland_altman_plot(col1, col2, plot_label: str, output_dir='',
                      col1_label='accelerometer', col2_label='actinet',
                      display_plot=False, show_y_label=False, ax=None,
                      activity_type=None):
    """Draw a Bland-Altman plot comparing two paired sequences of values.

    The x axis is the pairwise mean ``(col1 + col2) / 2`` and the y axis is
    the pairwise difference ``col1 - col2``. Horizontal lines mark the mean
    difference (red) and the limits of agreement at mean +/- 1.96 sample
    standard deviations (blue, dashed). The title reports the Pearson
    correlation between the two inputs.

    Args:
        col1: First sequence of values.
        col2: Second sequence of values, paired element-wise with ``col1``.
        plot_label: Label used in the title and in the output filename.
        output_dir: Directory for the PNG written in standalone mode. An
            empty string means the current working directory.
        col1_label: Name of the first method, used in axis labels/filename.
        col2_label: Name of the second method, used in axis labels/filename.
        display_plot: In standalone mode, call ``plt.show()`` after saving.
        show_y_label: Whether to label the y axis.
        ax: Existing axes to draw on. When ``None`` (standalone mode) a new
            figure is created, saved as
            ``ba_<plot_label>_<col1_label>_vs_<col2_label>.png`` and closed.
        activity_type: Selects the unit suffix on the axis labels: hours for
            'sleep'/'sedentary'/'light', minutes for 'moderate-vigorous',
            none otherwise.
    """
    dat = pd.DataFrame({'col1': col1, 'col2': col2})
    pearson_cor = dat.corr().iloc[0, 1]
    diffs = dat['col1'] - dat['col2']
    mean_diff = np.mean(diffs)
    sd_diff = np.std(diffs, ddof=1)
    lower_loa = mean_diff - 1.96 * sd_diff
    upper_loa = mean_diff + 1.96 * sd_diff

    mean_vals = (dat['col1'] + dat['col2']) / 2

    # Remember whether this call owns the figure *before* `ax` is rebound;
    # testing `ax is None` afterwards can never be true again.
    standalone = ax is None
    fig = None
    if standalone:
        fig, ax = plt.subplots(figsize=(10, 10), dpi=800)

    ax.scatter(mean_vals, diffs, color='black', alpha=1)
    ax.axhline(mean_diff, color='red', linestyle='-')
    ax.axhline(lower_loa, color='blue', linestyle='--')
    ax.axhline(upper_loa, color='blue', linestyle='--')

    # Annotations sit towards the right-hand side of the plotted data.
    text_x = 0.8 * max(mean_vals)
    ax.text(text_x, mean_diff, f'Mean Diff = {mean_diff:.2f}', va='bottom', color='red')
    ax.text(text_x, lower_loa, f'Lower LoA = {lower_loa:.2f}', va='bottom', color='blue')
    ax.text(text_x, upper_loa, f'Upper LoA = {upper_loa:.2f}', va='bottom', color='blue')

    if activity_type in ['sleep', 'sedentary', 'light']:
        unit_label = '[hours]'
    elif activity_type == 'moderate-vigorous':
        unit_label = '[minutes]'
    else:
        unit_label = ''

    ax.set_title(f'{plot_label.capitalize()} Activity | Pearson correlation: {pearson_cor:.3f}')

    ax.set_xlabel(f'({col1_label} + {col2_label}) / 2 {unit_label}')

    if show_y_label:
        ax.set_ylabel(f'{col1_label} - {col2_label} {unit_label}')

    ax.tick_params(axis='both', which='both', labelsize=14)

    if not standalone:
        # The caller owns the figure and is responsible for showing/saving it.
        return

    # os.makedirs('') raises, so only create a directory when one was given.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, f'ba_{plot_label}_{col1_label}_vs_{col2_label}.png')
    # Save through the figure object so another "current" figure is never
    # written by mistake, and save before showing.
    fig.savefig(plot_path, bbox_inches='tight')

    if display_plot:
        plt.show()

    plt.close(fig)
248+
249+
250+
def _bland_altman_grid(results_df, activities, **filters):
    """Draw one Bland-Altman panel per activity on a two-column grid.

    Returns the figure and the population reported by
    ``build_bland_altman_data`` for the last activity (``None`` when
    ``activities`` is empty).
    """
    n_cols = 2
    # Ceiling division so any number of activities fits; four activities
    # give the original 2x2 grid with a (15, 10) figure.
    n_rows = max(1, (len(activities) + n_cols - 1) // n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), dpi=800, sharey=False)
    axs = axs.flatten()

    population = None
    for ax, activity in zip(axs, activities):
        activity_bbaa_pred, activity_actinet_pred, population = build_bland_altman_data(
            results_df, activity, **filters)
        bland_altman_plot(activity_bbaa_pred, activity_actinet_pred, activity.capitalize(),
                          ax=ax, show_y_label=True, activity_type=activity)

    # Hide grid cells left over when the activity count is odd or below four.
    for unused_ax in axs[len(activities):]:
        unused_ax.set_visible(False)

    return fig, population


def generate_bland_altman_plots(results_df, activities=('sleep', 'sedentary', 'light', 'moderate-vigorous'),
                                group_by=None, save_path=None):
    """Generate Bland-Altman plots per activity, optionally stratified by a subgroup.

    Args:
        results_df: Results frame passed on to ``build_bland_altman_data``.
        activities: Activity names to plot, one panel each.
        group_by: Column to stratify by (e.g. "Sex" or "Age Band"). The column
            is read through ``.cat.categories``, so it must be categorical.
            The lower-cased, underscore-joined column name is used as the
            keyword argument to ``build_bland_altman_data``. ``None`` plots
            the full population in a single figure.
        save_path: Without ``group_by``, the PDF *file* path (nothing is saved
            when falsy). With ``group_by``, the *directory* receiving one
            ``<group_by>_<group>.pdf`` per group; when falsy the files are
            written to the current directory.
    """
    # Copy into a list so the default is immutable and any iterable works.
    activities = list(activities)

    if group_by is None:  # Full population
        fig, _ = _bland_altman_grid(results_df, activities)
        fig.suptitle("Bland-Altman plots for full Capture-24 population", fontsize=20)
        plot_and_save_fig(fig, save_path=save_path)
        # Release the high-dpi figure, matching the per-group branch.
        plt.close(fig)
        return

    filter_key = group_by.replace(' ', '_').lower()
    unique_groups = results_df[group_by].cat.categories

    for group in unique_groups:
        fig, population = _bland_altman_grid(results_df, activities, **{filter_key: group})

        fig.suptitle(f"Bland-Altman plots for different {group_by.lower()} in Capture-24 population\n" +
                     f"{group_by}: {group} (n={population})", fontsize=20)

        group_filename = f"{group_by.lower().replace(' ', '_')}_{group}.pdf"
        # Fall back to an explicit current directory: a bare filename has an
        # empty dirname, which breaks directory creation when saving.
        save_path_group = os.path.join(save_path or os.curdir, group_filename)

        plot_and_save_fig(fig, save_path=save_path_group)

        plt.close(fig)

0 commit comments

Comments
 (0)