22import seaborn as sns
33import numpy as np
44import pandas as pd
5+ import os
56from scipy import stats
67from sklearn .metrics import accuracy_score , f1_score , cohen_kappa_score , balanced_accuracy_score , confusion_matrix
78import warnings
@@ -57,14 +58,17 @@ def build_confusion_matrix_data(results: pd.DataFrame, age_band=None, sex=None):
5758 y_true = np .hstack (model_results .loc [model_results ["Model" ]== "actinet" , 'True' ])
5859 y_pred_bbaa = np .hstack (model_results .loc [model_results ["Model" ]== "accelerometer" , 'Predicted' ])
5960 y_pred_actinet = np .hstack (model_results .loc [model_results ["Model" ]== "actinet" , 'Predicted' ])
61+
62+ population = len (model_results ['Participant' ].unique ())
6063
61- return y_true , y_pred_bbaa , y_pred_actinet
64+ return y_true , y_pred_bbaa , y_pred_actinet , population
6265
6366
def plot_and_save_fig(fig, save_path=None):
    """Displays and optionally saves the figure as a PDF.

    Args:
        fig: Figure object exposing ``savefig``.
        save_path: Destination file path. When falsy (``None`` or ``''``)
            the figure is only displayed. Missing parent directories are
            created before saving.
    """
    plt.show()
    if save_path:
        # A bare filename such as "plot.pdf" has an empty dirname, and
        # os.makedirs('') raises FileNotFoundError, so only create the
        # parent directory when the path actually names one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        fig.savefig(save_path, format='pdf', dpi=800, bbox_inches='tight')
6973
7074
@@ -75,7 +79,7 @@ def generate_confusion_matrices(results_df, group_by=None, save_path=None):
7579 fig .suptitle ("Confusion matrices for full Capture-24 population using 5-fold group cross-validation" ,
7680 fontsize = 16 )
7781
78- y_true , y_pred_bbaa , y_pred_actinet = build_confusion_matrix_data (results_df )
82+ y_true , y_pred_bbaa , y_pred_actinet , _ = build_confusion_matrix_data (results_df )
7983
8084 plot_confusion_matrix (y_true , y_pred_bbaa , 'accelerometer' , ax = axs [0 ], fontsize = 14 )
8185 plot_confusion_matrix (y_true , y_pred_actinet , 'actinet' , ax = axs [1 ], fontsize = 14 )
@@ -90,13 +94,13 @@ def generate_confusion_matrices(results_df, group_by=None, save_path=None):
9094
9195 subfigs = fig .subfigures (nrows = len (unique_groups ), ncols = 1 )
9296
93- for i , (group , subfig ) in enumerate (zip (unique_groups , subfigs )):
94- subfig .suptitle (f"{ group_by } : { group } " , fontsize = 14 )
95- axs = subfig .subplots (nrows = 1 , ncols = 2 , sharey = True )
96-
97- y_true , y_pred_bbaa , y_pred_actinet = \
97+ for group , subfig in zip (unique_groups , subfigs ):
98+ y_true , y_pred_bbaa , y_pred_actinet , population = \
9899 build_confusion_matrix_data (results_df , ** {group_by .replace (' ' , '_' ).lower (): group })
99100
101+ subfig .suptitle (f"{ group_by } : { group } (n = { population } )" , fontsize = 14 )
102+ axs = subfig .subplots (nrows = 1 , ncols = 2 , sharey = True )
103+
100104 plot_confusion_matrix (y_true , y_pred_bbaa , 'accelerometer' , ax = axs [0 ], fontsize = 14 )
101105 plot_confusion_matrix (y_true , y_pred_actinet , 'actinet' , ax = axs [1 ], fontsize = 14 )
102106
@@ -169,3 +173,115 @@ def plot_boxplots(df, x, y='Macro F1', hue='Model'):
169173 ax .set_ylabel ("Macro F1 Score" )
170174 plt .title (f"Macro F1 by { x } " )
171175 plt .show ()
176+
177+
def build_bland_altman_data(results: pd.DataFrame, activity, age_band=None, sex=None):
    """Collects per-row predicted amounts of one activity for both models.

    Rows are optionally restricted to a single ``'Age Band'`` and/or ``'Sex'``
    value. For each remaining row the ``activity`` entry of its ``'Pred_dict'``
    mapping is read, defaulting to 0 when the key is absent.

    Args:
        results: Frame with 'Model', 'Pred_dict' and 'Participant' columns
            (plus 'Age Band' / 'Sex' when the matching filter is used).
        activity: Key looked up in every 'Pred_dict' entry.
        age_band: Keep only rows whose 'Age Band' equals this value.
        sex: Keep only rows whose 'Sex' equals this value.

    Returns:
        Tuple ``(accelerometer_values, actinet_values, population)`` where the
        first two are lists and ``population`` is the number of distinct
        'Participant' values left after filtering.
    """
    subset = results.copy()
    if age_band is not None:
        subset = subset[subset['Age Band'] == age_band]
    if sex is not None:
        subset = subset[subset["Sex"] == sex]

    # Moderate-vigorous values are multiplied by 60 (hours -> minutes).
    to_minutes = activity == 'moderate-vigorous'

    def _incidence(model_name):
        """Returns the activity amounts predicted by one model, row by row."""
        pred_dicts = subset.loc[subset["Model"] == model_name, "Pred_dict"]
        amounts = [entry.get(activity, 0) for entry in pred_dicts]
        if to_minutes:
            amounts = [60 * amount for amount in amounts]
        return amounts

    n_participants = len(subset['Participant'].unique())

    return _incidence("accelerometer"), _incidence("actinet"), n_participants
196+
197+
def bland_altman_plot(col1, col2, plot_label: str, output_dir='',
                      col1_label='accelerometer', col2_label='actinet',
                      display_plot=False, show_y_label=False, ax=None,
                      activity_type=None):
    """Generates a Bland-Altman plot for two columns of data.

    The difference ``col1 - col2`` is plotted against the pairwise mean,
    with horizontal lines for the mean difference and the limits of agreement
    (mean difference +/- 1.96 sample standard deviations). The Pearson
    correlation between the two columns is reported in the title.

    Args:
        col1: Measurements from the first method.
        col2: Measurements from the second method, paired with ``col1``.
        plot_label: Label used in the title and in the output file name.
        output_dir: Directory for the PNG written when this function creates
            its own figure; ``''`` means the current working directory.
        col1_label: Name of the first method in axis labels and file name.
        col2_label: Name of the second method in axis labels and file name.
        display_plot: Show the figure; only applies when ``ax`` is None.
        show_y_label: Whether to label the y axis.
        ax: Axes to draw on. When None a new figure is created, saved to
            ``output_dir`` and closed; when given, nothing is shown, saved or
            closed here and the caller keeps ownership of the figure.
        activity_type: Selects the unit suffix on axis labels: '[hours]' for
            'sleep', 'sedentary' and 'light', '[minutes]' for
            'moderate-vigorous', none otherwise.
    """
    dat = pd.DataFrame({'col1': col1, 'col2': col2})
    pearson_cor = dat.corr().iloc[0, 1]
    diffs = dat['col1'] - dat['col2']
    mean_diff = np.mean(diffs)
    sd_diff = np.std(diffs, ddof=1)
    lower_loa = mean_diff - 1.96 * sd_diff
    upper_loa = mean_diff + 1.96 * sd_diff

    mean_vals = (dat['col1'] + dat['col2']) / 2

    # Record figure ownership *before* ax is rebound; testing `ax is None`
    # after this point is always False and would skip the show/save step.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(10, 10), dpi=800)

    ax.scatter(mean_vals, diffs, color='black', alpha=1)
    ax.axhline(mean_diff, color='red', linestyle='-')
    ax.axhline(lower_loa, color='blue', linestyle='--')
    ax.axhline(upper_loa, color='blue', linestyle='--')

    # Annotations sit at 80% of the largest mean value on the x axis.
    text_x = 0.8 * mean_vals.max()
    ax.text(text_x, mean_diff, f'Mean Diff = {mean_diff:.2f}', va='bottom', color='red')
    ax.text(text_x, lower_loa, f'Lower LoA = {lower_loa:.2f}', va='bottom', color='blue')
    ax.text(text_x, upper_loa, f'Upper LoA = {upper_loa:.2f}', va='bottom', color='blue')

    if activity_type in ['sleep', 'sedentary', 'light']:
        unit_label = '[hours]'
    elif activity_type == 'moderate-vigorous':
        unit_label = '[minutes]'
    else:
        unit_label = ''

    ax.set_title(f'{plot_label.capitalize()} Activity | Pearson correlation: {pearson_cor:.3f}')

    ax.set_xlabel(f'({col1_label} + {col2_label}) / 2 {unit_label}')

    if show_y_label:
        ax.set_ylabel(f'{col1_label} - {col2_label} {unit_label}')

    ax.tick_params(axis='both', which='both', labelsize=14)

    if owns_figure:
        # os.makedirs('') raises, so skip it for the default output_dir.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        plot_path = os.path.join(output_dir, f'ba_{plot_label}_{col1_label}_vs_{col2_label}.png')
        # Save through the figure handle before showing, so the file does
        # not depend on the figure still being current after plt.show().
        fig.savefig(plot_path, bbox_inches='tight')
        if display_plot:
            plt.show()
        plt.close(fig)
248+
249+
def generate_bland_altman_plots(results_df, activities=('sleep', 'sedentary', 'light', 'moderate-vigorous'),
                                group_by=None, save_path=None):
    """Generates Bland-Altman plots for different activities, optionally stratified by a subgroup.

    Args:
        results_df: Results frame handed to ``build_bland_altman_data``.
        activities: Activity labels, one panel each. Panels are laid out on a
            fixed 2x2 grid, so at most four activities fit.
        group_by: Column to stratify by. ``None`` draws one figure for the
            whole population; otherwise one figure is drawn per category in
            ``results_df[group_by].cat.categories``. The column name,
            lower-cased with spaces replaced by underscores, is passed as the
            filter keyword to ``build_bland_altman_data``.
        save_path: With ``group_by=None`` this is the output file path
            (nothing is saved when falsy). With ``group_by`` set it is the
            output directory; one PDF per group is always written, into the
            current working directory when ``save_path`` is falsy.
    """
    if group_by is None:  # Full population
        fig, axs = plt.subplots(2, 2, figsize=(15, 10), dpi=800, sharey=False)
        fig.suptitle("Bland-Altman plots for full Capture-24 population", fontsize=20)
        axs = axs.flatten()

        for i, activity in enumerate(activities):
            activity_bbaa_pred, activity_actinet_pred, _ = build_bland_altman_data(results_df, activity)
            bland_altman_plot(activity_bbaa_pred, activity_actinet_pred, activity.capitalize(),
                              ax=axs[i], show_y_label=True, activity_type=activity)

        plot_and_save_fig(fig, save_path=save_path)
        return

    # The filter keyword and output directory do not change between groups.
    filter_key = group_by.replace(' ', '_').lower()
    # Fall back to an explicit "." so the saved path always has a directory
    # component; a bare filename has an empty dirname, which os.makedirs
    # rejects.
    out_dir = save_path or os.curdir

    for group in results_df[group_by].cat.categories:
        fig, axs = plt.subplots(2, 2, figsize=(15, 10), dpi=800, sharey=False)
        axs = axs.flatten()

        for i, activity in enumerate(activities):
            activity_bbaa_pred, activity_actinet_pred, population = build_bland_altman_data(
                results_df, activity, **{filter_key: group})
            bland_altman_plot(activity_bbaa_pred, activity_actinet_pred, activity.capitalize(),
                              ax=axs[i], show_y_label=True, activity_type=activity)

        # population comes from the last activity processed for this group.
        fig.suptitle(f"Bland-Altman plots for different {group_by.lower()} in Capture-24 population\n" +
                     f"{group_by}: {group} (n={population})", fontsize=20)

        group_filename = f"{group_by.lower().replace(' ', '_')}_{group}.pdf"
        save_path_group = os.path.join(out_dir, group_filename)

        plot_and_save_fig(fig, save_path=save_path_group)

        # Release the figure before the next group to keep memory bounded.
        plt.close(fig)
0 commit comments