1+ from imports import *
2+
3+ class cohortAnalysis ():
4+ def __init__ (self ):
5+ self .cohorts = {}
6+ self .cohort_metrics = []
7+ self .cohort_set = {}
8+
9+
10+ def filtered_dataframe (self , df , filter_variable , var_name = "" , operator = "" , value = "" ):
11+ """
12+ data = main_data
13+ name = cohort_name
14+ var_name = name of the variable to slice/dice
15+ operator: >, <, =, >=, <=
16+ value = value of the variable
17+
18+ returns main_data: filtered dataset with just the probabilities
19+ name: filtered dataset with the condition
20+ """
21+ main_dataset = df [filter_variable ]
22+ if (var_name != "" ) or (operator != "" ) or (value != "" ):
23+ if len (df [filter_variable ]) != 0 :
24+ name = df .query ("{} {} {}" .format (var_name , operator , value ))[filter_variable ]
25+ condition = str (var_name )+ str (operator )+ str (value )
26+ return main_dataset , name , condition
27+ else :
28+ pass
29+ else :
30+ if len (df [filter_variable ]) != 0 :
31+ condition = "All Data"
32+ return main_dataset , condition
33+ else :
34+ pass
35+
36+ def add_cohort (self , df , filter_variable , var_name = "" , operator = "" , value = "" ):
37+ if (var_name != "" ) or (operator != "" ) or (value != "" ):
38+ main_dataset , name , condition = self .filtered_dataframe (df ,filter_variable ,var_name ,operator ,value )
39+ self .cohorts [condition ] = name
40+ else :
41+ main_dataset , condition = self .filtered_dataframe (df , filter_variable )
42+ self .cohorts [condition ] = main_dataset
43+
44+ def remove_cohort (self ):
45+ if (len (self .cohorts ) > 1 ) and (len (self .cohort_set ) > 1 ):
46+ self .cohorts .popitem ()
47+ self .cohort_set .popitem ()
48+ else :
49+ pass
50+
51+ def add_cohort_metrics (self , df , var_name = "" , operator = "" , value = "" , is_classification = True ):
52+ """
53+ data = main_data
54+ name = cohort_name
55+ var_name = name of the variable to slice/dice
56+ operator: >, <, =, >=, <=
57+ value = value of the variable
58+
59+ """
60+ if value != "" :
61+ #Extract filtered predicted values
62+ _ , predicted , condition_predict = self .filtered_dataframe (df , "y_prediction" ,var_name ,operator ,value )
63+ #Extract filtered true labels
64+ _ , true_values , condition_true = self .filtered_dataframe (df , "y_actual" , var_name , operator , value )
65+ #calculate metrics
66+ if is_classification is True :
67+ if len (true_values ) != 0 :
68+ accuracy , precision , recall , fpr , fnr = self .classification_cohort_metrics (true_values , predicted )
69+ self .cohort_set [condition_predict ] = self .generate_classification_divs (accuracy , precision , recall , fpr , fnr )
70+ else :
71+ pass
72+ else :
73+ if len (true_values ) != 0 :
74+ mae , mse , r2 = self .regression_cohort_metrics (true_values , predicted )
75+ #save these metrics to an array
76+ self .cohort_set [condition_predict ] = self .generator_regression_divs (mae , mse , r2 )
77+ else :
78+ pass
79+ else :
80+ main_dataset , condition = self .filtered_dataframe (df , "y_prediction" )
81+ true_data , _ = self .filtered_dataframe (df , "y_actual" )
82+ if is_classification is True :
83+ if len (true_data ) != 0 :
84+ accuracy , precision , recall , fpr , fnr = self .classification_cohort_metrics (true_data ,main_dataset )
85+ self .cohort_set [condition ] = self .generate_classification_divs (accuracy , precision , recall , fpr , fnr )
86+ else :
87+ pass
88+ else :
89+ if len (true_data ) != 0 :
90+ mae , mse , r2 = self .regression_cohort_metrics (true_data , main_dataset )
91+ #save these metrics to an array
92+ self .cohort_set [condition ] = self .generator_regression_divs (mae , mse , r2 )
93+ else :
94+ pass
95+
96+ def generate_classification_divs (self , accuracy , precision , recall , fpr , fnr ):
97+ metrics_div = [html .Div ("Accuracy: {}" .format (accuracy )),
98+ html .Div ("Precision: {}" .format (precision )),
99+ html .Div ("Recall: {}" .format (recall )),
100+ html .Div ("fpr: {}" .format (fpr )),
101+ html .Div ("fnr: {}" .format (fnr ))
102+ ]
103+ return metrics_div
104+
105+ def generator_regression_divs (self , mae , mse , r2 ):
106+ metrics_div = [html .Div ("MAE : {}" .format (mae )),
107+ html .Div ("MSE : {}" .format (mse )),
108+ html .Div ("R2: {}" .format (r2 ))]
109+ return metrics_div
110+
111+ def cohort_details (self ):
112+ """
113+ Cohort Name
114+ Length of Cohort
115+ """
116+ length_dict = {key : len (value ) for key , value in self .cohorts .items ()}
117+ divs = []
118+ for i in range (len (length_dict )):
119+ if list (length_dict .values ())[i ] != 0 :
120+ first_html = html .Div (list (length_dict .keys ())[i ])
121+ second_html = html .Div (str (list (length_dict .values ())[i ])+ " datapoints" )
122+ divs .append (html .Div ([first_html ,second_html ], style = {"padding-left" :"20px" ,"padding-right" :"20px" ,"padding-bottom" :"0px" ,"width" :"200px" }))
123+ else :
124+ pass
125+ return divs
126+
127+ def cohort_metrics_details (self ):
128+ """
129+ Cohort Name
130+ Metrics
131+ """
132+ length_dict = {key : value for key , value in self .cohort_set .items ()}
133+ div_metrics = []
134+ for i in range (len (length_dict )):
135+ div_metrics .append (html .Div (list (length_dict .values ())[i ], style = {"padding-left" :"20px" ,"padding-right" :"20px" ,"padding-bottom" :"0px" ,"width" :"200px" }))
136+ return div_metrics
137+
138+
139+ def cohort_graph (self , filter_variable ):
140+ """[This function generators the box plot for the cohorts. This is operated directly from the frontend.]
141+
142+ Args:
143+ filter_variable ([string]): [This variable is x-axis value of the graph. It can be either probabilities or model prediction values]
144+
145+ Returns:
146+ [figure]: [box plot graph]
147+ """
148+
149+ X_Value = str (filter_variable )
150+ Y_Value = 'Cohorts'
151+
152+ fig = go .Figure ()
153+
154+ for k , v in self .cohorts .items ():
155+ fig .add_trace (go .Box (x = v , name = k ))
156+
157+ fig .update_layout (
158+ yaxis_title = Y_Value ,
159+ xaxis_title = X_Value ,
160+ template = "plotly_white" ,
161+ font = dict (
162+ size = 8 ,
163+ )
164+ )
165+ fig .update_layout (legend = dict (
166+ orientation = "h" ,
167+ yanchor = "bottom" ,
168+ y = 1.02 ,
169+ xanchor = "right" ,
170+ x = 1
171+ ))
172+
173+ fig .update_layout (
174+ margin = {"t" :0 },
175+ )
176+ return fig
177+
178+ def classification_cohort_metrics (self , y_true , predicted ):
179+ """[Calculates the metrics for classification problems]
180+
181+ Args:
182+ y_true ([type]): [True labels from the dataset]
183+ predicted ([type]): [Predicted values from the model]
184+
185+ Returns:
186+ Accuracy metric of the model
187+ Precision value of the model
188+ Recall value of the model
189+ False Positive Rate
190+ False Negative Rate
191+ """
192+ if len (y_true ) != 0 :
193+ #Accuracy
194+ accuracy = round (metrics .accuracy_score (y_true , predicted ),2 ) #Accuracy classification score.
195+ #Precision
196+ precision = round (metrics .precision_score (y_true , predicted , zero_division = 1 ),2 ) #Compute the precision
197+ #Recall
198+ recall = round (metrics .recall_score (y_true , predicted , zero_division = 1 ),2 ) #Compute the recall
199+ #False Positive Rate (FPR)
200+ tn , fp , fn , tp = metrics .confusion_matrix (y_true , predicted ).ravel () #Compute confusion matrix to evaluate the accuracy of a classification.
201+ #False Negative Rate (FNR)
202+ fpr = round ((fp / (fp + tn )),2 )
203+ fnr = round ((fn / (tp + fn ) if (tp + fn ) else 0 ),2 )
204+
205+ return accuracy , precision , recall , fpr , fnr
206+ else :
207+ pass
208+
209+ def regression_cohort_metrics (self , y_true , predicted ):
210+ """[Calculates the metrics for regression problems]
211+
212+ Args:
213+ y_true ([type]): [True labels from the dataset]
214+ predicted ([type]): [Predicted values from the model]
215+
216+ Returns:
217+ Mean Absolute Error
218+ Mean Squared Error
219+ R-Squared Value
220+ """
221+ if len (y_true ) != 0 :
222+ #Mean Absolute Error
223+ mae = round (metrics .mean_absolute_error (y_true , predicted ),2 )
224+ #Mean Squared Error
225+ mse = round (metrics .mean_squared_error (y_true , predicted ),2 )
226+ #R2
227+ r2 = round (metrics .r2_score (y_true , predicted ),2 )
228+
229+ return mae , mse , r2
230+ else :
231+ pass
232+
233+
234+
235+
0 commit comments