Skip to content

Commit 45392c5

Browse files
authored
Merge pull request #24 from explainX/cohort_analysis_update
apps updated
2 parents 94d953e + 6995fe1 commit 45392c5

File tree

4 files changed

+329
-1
lines changed

4 files changed

+329
-1
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ dist/
1515
downloads/
1616
eggs/
1717
.eggs/
18-
lib/
1918
lib64/
2019
parts/
2120
sdist/

lib/apps/cohort.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from imports import *
2+
from plotly_graphs import *
3+
from plotly_css import *
4+
from app import app
5+
6+
def cohort_layout(original_variables):
    """Build the Dash layout for the cohort-analysis page.

    Args:
        original_variables: iterable of dataset column names used to
            populate the dropdowns. The last four entries appear to be
            model-output columns (probabilities / predictions) -- TODO
            confirm against the caller that builds this list.

    Returns:
        dash_html_components.Div: page layout rooted at id="main".
    """
    # Variable selector shown inside the "Add Cohort" modal.
    var_name_dropdown = html.Div([
        html.P("Choose Variable"),
        dcc.Dropdown(
            id='demo-dropdown',
            options=[{'label': i, 'value': i} for i in original_variables],
            value="",
            clearable=False
        )
    ])

    # Comparison operators accepted by cohortAnalysis.filtered_dataframe.
    operators_list = ["==", ">", "<", ">=", "<="]

    operators_dropdown = html.Div([
        html.P("Choose Operator"),
        dcc.Dropdown(id="demo-operators",
                     options=[{"label": i, "value": i} for i in operators_list],
                     value="",
                     clearable=False)
    ])

    value_input = html.Div([
        html.P("Enter Value"),
        dcc.Input(id="demo-values",
                  type="text",
                  value="",
                  # debounce: fire the callback only after editing finishes,
                  # not on every keystroke.
                  debounce=True)
    ])

    # Default x-axis is the second-to-last column when available.
    # Fallbacks avoid an IndexError when fewer than two variables exist
    # (the original code crashed in that case).
    if len(original_variables) >= 2:
        default_x = original_variables[-2]
    elif original_variables:
        default_x = original_variables[-1]
    else:
        default_x = ""

    x_axis_dropdown = html.Div([
        html.P("Choose X-Axis Variable"),
        dcc.Dropdown(id="x-axis",
                     options=[{"label": i, "value": i} for i in original_variables[-4:]],
                     value=default_x,
                     clearable=False)
    ], style={"width": "30%", "padding-left": "20px"})

    # Modal dialog collecting (variable, operator, value) for a new cohort.
    modal = html.Div(
        [
            dbc.Modal(
                [
                    dbc.ModalHeader("Cohort Analysis"),
                    dbc.ModalBody(
                        html.Div(
                            [var_name_dropdown,
                             operators_dropdown,
                             value_input
                             ], id="modal_body")
                    ),
                    dbc.ModalFooter([
                        # NOTE(review): n_clicks starts at 3 -- looks like a
                        # deliberate offset for callback gating; confirm.
                        dbc.Button("Add Cohort", id="add-cohort", n_clicks=3),
                        dbc.Button("Close", id="close", className="ml-auto")
                    ])],
                id="modal",
            ),
        ], id="modal-parent"
    )

    # Page-level controls and placeholders filled by callbacks.
    button = dbc.Button("Add Cohort", id="open")
    remove_button = dbc.Button("Remove Cohort", id="remove-cohort", style={"margin-left": "20px"})
    cohort_details = html.Div(id="cohort-details", children=[], style={"display": "flex"})
    cohort_metrics_div = html.Div(id="cohort-metrics-div", children=[], style={"display": "flex"})

    heading = html.H3("Evaluate Model Performance - Cohort Analysis", style={"padding-left": "20px", "padding-top": "20px"})

    details = html.P("Evaluate the performance of your model by exploring the distribution of your prediction value and the values of your model performance metrics. You can further investigate your model by looking at a comparative analysis of its performance across different cohorts or subgroups of your dataset. Select filters along y-value and x-value to cut across different dimensions.", style={"padding-left": '20px'})

    layout = html.Div(
        [
            heading,
            details,
            x_axis_dropdown,
            html.Div([button, remove_button], style={"padding": "20px", "display": "flex"}),
            cohort_details,
            cohort_metrics_div,
            modal,
            html.Div(id="cohort-graph")
        ], id="main"
    )

    return layout

lib/apps/webapp/static/xai.png

12.2 KB
Loading

lib/cohort_analysis.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
from imports import *
2+
3+
class cohortAnalysis():
    """Track dataset cohorts (filtered subsets) and render their performance
    metrics and distribution plots for the cohort-analysis dashboard.

    Attributes:
        cohorts: dict mapping a condition string (e.g. "age>30" or
            "All Data") to the filtered Series of the chosen filter variable.
        cohort_metrics: currently unused accumulator; kept for backward
            compatibility with existing callers.
        cohort_set: dict mapping a condition string to a list of html.Div
            rows describing that cohort's metrics.
    """

    def __init__(self):
        # condition -> filtered data
        self.cohorts = {}
        # unused, retained for interface compatibility
        self.cohort_metrics = []
        # condition -> metric Divs
        self.cohort_set = {}

    def filtered_dataframe(self, df, filter_variable, var_name="", operator="", value=""):
        """Extract *filter_variable* from *df*, optionally restricted by a
        "<var_name> <operator> <value>" condition.

        Args:
            df: source pandas DataFrame.
            filter_variable: column to extract (e.g. "y_prediction").
            var_name: column to filter on; empty string means no filter.
            operator: one of ==, >, <, >=, <= (passed to DataFrame.query).
            value: comparison value. NOTE(review): interpolated unquoted
                into the query string, so string values must already carry
                quotes; numeric values work as-is -- confirm with frontend.

        Returns:
            (full_series, filtered_series, condition) when any condition
            part was supplied, else (full_series, "All Data").

        Fix vs. original: the original returned None (via bare ``pass``)
        when the column was empty, which made callers crash while unpacking
        the result; consistent tuples are now always returned.
        """
        main_dataset = df[filter_variable]
        if (var_name != "") or (operator != "") or (value != ""):
            # Build a pandas query such as "age > 30" and keep only the
            # filter variable of the matching rows.
            name = df.query("{} {} {}".format(var_name, operator, value))[filter_variable]
            condition = str(var_name)+str(operator)+str(value)
            return main_dataset, name, condition
        # No condition supplied: the whole column is the cohort.
        condition = "All Data"
        return main_dataset, condition

    def add_cohort(self, df, filter_variable, var_name="", operator="", value=""):
        """Register a cohort in ``self.cohorts`` keyed by its condition
        string ("All Data" when no condition parts are given)."""
        if (var_name != "") or (operator != "") or (value != ""):
            _, name, condition = self.filtered_dataframe(df, filter_variable, var_name, operator, value)
            self.cohorts[condition] = name
        else:
            main_dataset, condition = self.filtered_dataframe(df, filter_variable)
            self.cohorts[condition] = main_dataset

    def remove_cohort(self):
        """Drop the most recently added cohort and its metrics, but never
        remove the last remaining entry of either dict."""
        if (len(self.cohorts) > 1) and (len(self.cohort_set) > 1):
            self.cohorts.popitem()
            self.cohort_set.popitem()

    def add_cohort_metrics(self, df, var_name="", operator="", value="", is_classification=True):
        """Compute performance metrics for one cohort and store the rendered
        Divs in ``self.cohort_set``.

        Args:
            df: DataFrame that must contain "y_prediction" and "y_actual"
                columns.
            var_name/operator/value: optional filter, as in
                :meth:`filtered_dataframe`; an empty ``value`` means the
                whole dataset.
            is_classification: True -> accuracy/precision/recall/fpr/fnr;
                False -> MAE/MSE/R2.
        """
        if value != "":
            # Filtered cohort: slice predictions and true labels the same way.
            _, predicted, condition_predict = self.filtered_dataframe(df, "y_prediction", var_name, operator, value)
            _, true_values, _ = self.filtered_dataframe(df, "y_actual", var_name, operator, value)
            # Skip empty cohorts -- there is nothing to score.
            if len(true_values) == 0:
                return
            if is_classification is True:
                accuracy, precision, recall, fpr, fnr = self.classification_cohort_metrics(true_values, predicted)
                self.cohort_set[condition_predict] = self.generate_classification_divs(accuracy, precision, recall, fpr, fnr)
            else:
                mae, mse, r2 = self.regression_cohort_metrics(true_values, predicted)
                self.cohort_set[condition_predict] = self.generator_regression_divs(mae, mse, r2)
        else:
            # Unfiltered cohort: score the whole dataset.
            main_dataset, condition = self.filtered_dataframe(df, "y_prediction")
            true_data, _ = self.filtered_dataframe(df, "y_actual")
            if len(true_data) == 0:
                return
            if is_classification is True:
                accuracy, precision, recall, fpr, fnr = self.classification_cohort_metrics(true_data, main_dataset)
                self.cohort_set[condition] = self.generate_classification_divs(accuracy, precision, recall, fpr, fnr)
            else:
                mae, mse, r2 = self.regression_cohort_metrics(true_data, main_dataset)
                self.cohort_set[condition] = self.generator_regression_divs(mae, mse, r2)

    def generate_classification_divs(self, accuracy, precision, recall, fpr, fnr):
        """Render classification metrics as a list of html.Div rows."""
        metrics_div = [html.Div("Accuracy: {}".format(accuracy)),
                       html.Div("Precision: {}".format(precision)),
                       html.Div("Recall: {}".format(recall)),
                       html.Div("fpr: {}".format(fpr)),
                       html.Div("fnr: {}".format(fnr))
                       ]
        return metrics_div

    def generator_regression_divs(self, mae, mse, r2):
        """Render regression metrics as a list of html.Div rows.

        (Name keeps the original "generator_" spelling for compatibility
        with existing callers.)
        """
        metrics_div = [html.Div("MAE : {}".format(mae)),
                       html.Div("MSE : {}".format(mse)),
                       html.Div("R2: {}".format(r2))]
        return metrics_div

    def cohort_details(self):
        """Return one html.Div (condition + row count) per non-empty cohort."""
        divs = []
        for condition, data in self.cohorts.items():
            count = len(data)
            if count == 0:
                continue  # empty cohorts are not worth displaying
            first_html = html.Div(condition)
            second_html = html.Div(str(count)+" datapoints")
            divs.append(html.Div([first_html, second_html], style={"padding-left": "20px", "padding-right": "20px", "padding-bottom": "0px", "width": "200px"}))
        return divs

    def cohort_metrics_details(self):
        """Return one html.Div per cohort wrapping its stored metric rows."""
        div_metrics = []
        for metric_rows in self.cohort_set.values():
            div_metrics.append(html.Div(metric_rows, style={"padding-left": "20px", "padding-right": "20px", "padding-bottom": "0px", "width": "200px"}))
        return div_metrics

    def cohort_graph(self, filter_variable):
        """Build a horizontal box plot with one trace per cohort.

        Args:
            filter_variable: x-axis label; the variable whose distribution
                each cohort trace shows (probabilities or predictions).

        Returns:
            plotly.graph_objects.Figure: the box-plot figure.
        """
        X_Value = str(filter_variable)
        Y_Value = 'Cohorts'

        fig = go.Figure()

        for condition, data in self.cohorts.items():
            fig.add_trace(go.Box(x=data, name=condition))

        fig.update_layout(
            yaxis_title=Y_Value,
            xaxis_title=X_Value,
            template="plotly_white",
            font=dict(
                size=8,
            )
        )
        # Horizontal legend above the plot, right-aligned.
        fig.update_layout(legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ))
        fig.update_layout(
            margin={"t": 0},
        )
        return fig

    def classification_cohort_metrics(self, y_true, predicted):
        """Compute classification metrics for one cohort.

        Args:
            y_true: true labels of the cohort.
            predicted: model predictions for the same rows.

        Returns:
            (accuracy, precision, recall, fpr, fnr), each rounded to 2
            decimal places, or None when *y_true* is empty.

        Fix vs. original: fpr now guards against a zero denominator the
        same way fnr always did (a cohort with no actual negatives used to
        raise ZeroDivisionError).
        """
        if len(y_true) != 0:
            accuracy = round(metrics.accuracy_score(y_true, predicted), 2)
            # zero_division=1 avoids warnings/NaNs when a class is absent.
            precision = round(metrics.precision_score(y_true, predicted, zero_division=1), 2)
            recall = round(metrics.recall_score(y_true, predicted, zero_division=1), 2)
            # NOTE(review): .ravel() assumes a binary confusion matrix
            # (exactly 2 classes) -- confirm upstream guarantees this.
            tn, fp, fn, tp = metrics.confusion_matrix(y_true, predicted).ravel()
            # Guard both rates against empty denominators.
            fpr = round((fp/(fp+tn) if (fp+tn) else 0), 2)
            fnr = round((fn/(tp+fn) if (tp+fn) else 0), 2)

            return accuracy, precision, recall, fpr, fnr
        else:
            pass

    def regression_cohort_metrics(self, y_true, predicted):
        """Compute regression metrics for one cohort.

        Args:
            y_true: true values of the cohort.
            predicted: model predictions for the same rows.

        Returns:
            (mae, mse, r2), each rounded to 2 decimal places, or None when
            *y_true* is empty.
        """
        if len(y_true) != 0:
            mae = round(metrics.mean_absolute_error(y_true, predicted), 2)
            mse = round(metrics.mean_squared_error(y_true, predicted), 2)
            r2 = round(metrics.r2_score(y_true, predicted), 2)

            return mae, mse, r2
        else:
            pass

0 commit comments

Comments
 (0)