Skip to content

Commit 1f3f78c

Browse files
authored
Merge pull request #53 from datakind/update-inference-output-file
Update inference output file per discussion
2 parents db37b7d + 7db2d5c commit 1f3f78c

File tree

2 files changed

+56
-46
lines changed

2 files changed

+56
-46
lines changed

src/student_success_tool/modeling/inference.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as t
22

33
import numpy as np
4+
import numpy.typing as npt
45
import pandas as pd
56
from shap import KernelExplainer
67

@@ -9,8 +10,9 @@ def select_top_features_for_display(
910
features: pd.DataFrame,
1011
unique_ids: pd.Series,
1112
predicted_probabilities: list[float],
12-
shap_values: pd.Series,
13+
shap_values: npt.NDArray[np.float64],
1314
n_features: int = 3,
15+
needs_support_threshold_prob: t.Optional[float] = 0.5,
1416
features_table: t.Optional[dict[str, dict[str, str]]] = None,
1517
) -> pd.DataFrame:
1618
"""
@@ -24,6 +26,14 @@ def select_top_features_for_display(
2426
order as unique_ids, of shape len(unique_ids)
2527
shap_values: array of arrays of SHAP values, of shape len(unique_ids)
2628
n_features: number of important features to return
29+
needs_support_threshold_prob: Minimum probability in [0.0, 1.0] used to compute
30+
a boolean "needs support" field added to output records. Values in
31+
``predicted_probabilities`` greater than or equal to this threshold result in
32+
a True value; otherwise it's False. If this threshold is set to None,
33+
then no "needs support" values are added to the output records.
34+
Note that this doesn't have to be the "optimal" decision threshold for
35+
the trained model that produced ``predicted_probabilities``, it can
36+
be tailored to a school's preferences and use case.
2737
features_table: Optional mapping of column to human-friendly feature name/desc,
2838
loaded via :func:`utils.load_features_table()`
2939
@@ -32,18 +42,26 @@ def select_top_features_for_display(
3242
3343
TODO: refactor this functionality so it's vectorized and aggregates by student
3444
"""
35-
top_features_info = []
45+
pred_probs = np.asarray(predicted_probabilities)
3646

37-
for i, (unique_id, predicted_proba) in enumerate(
38-
zip(unique_ids, predicted_probabilities)
39-
):
47+
top_features_info = []
48+
for i, (unique_id, predicted_proba) in enumerate(zip(unique_ids, pred_probs)):
4049
instance_shap_values = shap_values[i]
4150
top_indices = np.argsort(-np.abs(instance_shap_values))[:n_features]
4251
top_features = features.columns[top_indices]
4352
top_feature_values = features.iloc[i][top_features]
4453
top_shap_values = instance_shap_values[top_indices]
4554

46-
for rank, (feature, feature_value, shap_value) in enumerate(
55+
student_output = {
56+
"Student ID": unique_id,
57+
"Support Score": predicted_proba,
58+
}
59+
if needs_support_threshold_prob is not None:
60+
student_output["Support Needed"] = (
61+
predicted_proba >= needs_support_threshold_prob
62+
)
63+
64+
for feature_rank, (feature, feature_value, shap_value) in enumerate(
4765
zip(top_features, top_feature_values, top_shap_values), start=1
4866
):
4967
feature_name = (
@@ -54,16 +72,18 @@ def select_top_features_for_display(
5472
if features_table is not None
5573
else feature
5674
)
57-
top_features_info.append(
58-
{
59-
"Student ID": unique_id,
60-
"Support Score": predicted_proba,
61-
"Top Indicators": feature_name,
62-
"Indicator Value": feature_value,
63-
"SHAP Value": shap_value,
64-
"Rank": rank,
65-
}
75+
feature_value = (
76+
str(round(feature_value, 2))
77+
if isinstance(feature_value, float)
78+
else str(feature_value)
6679
)
80+
student_output |= {
81+
f"Feature_{feature_rank}_Name": feature_name,
82+
f"Feature_{feature_rank}_Value": feature_value,
83+
f"Feature_{feature_rank}_Importance": round(shap_value, 2),
84+
}
85+
86+
top_features_info.append(student_output)
6787
return pd.DataFrame(top_features_info)
6888

6989

tests/modeling/test_inference.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def explainer():
2828
"predicted_probabilities",
2929
"shap_values",
3030
"n_features",
31+
"needs_support_threshold_prob",
3132
"features_table",
3233
"exp",
3334
],
@@ -37,7 +38,7 @@ def explainer():
3738
{
3839
"x1": ["val1", "val2", "val3"],
3940
"x2": [True, False, True],
40-
"x3": [2.0, 1.0, 0.5],
41+
"x3": [2.0, 1.0001, 0.5],
4142
"x4": [1, 2, 3],
4243
}
4344
),
@@ -47,39 +48,26 @@ def explainer():
4748
[[1.0, 0.9, 0.8, 0.7], [0.0, -1.0, 0.9, -0.8], [0.25, 0.0, -0.5, 0.75]]
4849
),
4950
3,
51+
0.5,
5052
{
5153
"x1": {"name": "feature #1"},
5254
"x2": {"name": "feature #2"},
5355
"x3": {"name": "feature #3"},
5456
},
5557
pd.DataFrame(
5658
{
57-
"Student ID": [1, 1, 1, 2, 2, 2, 3, 3, 3],
58-
"Support Score": [0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.5, 0.5, 0.5],
59-
"Top Indicators": [
60-
"feature #1",
61-
"feature #2",
62-
"feature #3",
63-
"feature #2",
64-
"feature #3",
65-
"x4",
66-
"x4",
67-
"feature #3",
68-
"feature #1",
69-
],
70-
"Indicator Value": [
71-
"val1",
72-
True,
73-
2.0,
74-
False,
75-
1.0,
76-
2,
77-
3,
78-
0.5,
79-
"val3",
80-
],
81-
"SHAP Value": [1.0, 0.9, 0.8, -1.0, 0.9, -0.8, 0.75, -0.5, 0.25],
82-
"Rank": [1, 2, 3, 1, 2, 3, 1, 2, 3],
59+
"Student ID": [1, 2, 3],
60+
"Support Score": [0.9, 0.1, 0.5],
61+
"Support Needed": [True, False, True],
62+
"Feature_1_Name": ["feature #1", "feature #2", "x4"],
63+
"Feature_1_Value": ["val1", "False", "3"],
64+
"Feature_1_Importance": [1.0, -1.0, 0.75],
65+
"Feature_2_Name": ["feature #2", "feature #3", "feature #3"],
66+
"Feature_2_Value": ["True", "1.0", "0.5"],
67+
"Feature_2_Importance": [0.9, 0.9, -0.5],
68+
"Feature_3_Name": ["feature #3", "x4", "feature #1"],
69+
"Feature_3_Value": ["2.0", "2", "val3"],
70+
"Feature_3_Importance": [0.8, -0.8, 0.25],
8371
}
8472
),
8573
),
@@ -99,14 +87,14 @@ def explainer():
9987
),
10088
1,
10189
None,
90+
None,
10291
pd.DataFrame(
10392
{
10493
"Student ID": [1, 2, 3],
10594
"Support Score": [0.9, 0.1, 0.5],
106-
"Top Indicators": ["x1", "x2", "x4"],
107-
"Indicator Value": ["val1", False, 3],
108-
"SHAP Value": [1.0, -1.0, 0.75],
109-
"Rank": [1, 1, 1],
95+
"Feature_1_Name": ["x1", "x2", "x4"],
96+
"Feature_1_Value": ["val1", "False", "3"],
97+
"Feature_1_Importance": [1.0, -1.0, 0.75],
11098
}
11199
),
112100
),
@@ -118,6 +106,7 @@ def test_select_top_features_for_display(
118106
predicted_probabilities,
119107
shap_values,
120108
n_features,
109+
needs_support_threshold_prob,
121110
features_table,
122111
exp,
123112
):
@@ -127,6 +116,7 @@ def test_select_top_features_for_display(
127116
predicted_probabilities,
128117
shap_values,
129118
n_features=n_features,
119+
needs_support_threshold_prob=needs_support_threshold_prob,
130120
features_table=features_table,
131121
)
132122
assert isinstance(obs, pd.DataFrame) and not obs.empty

0 commit comments

Comments
 (0)