Skip to content

Commit da116c6

Browse files
committed
Update tests
1 parent 121789b commit da116c6

File tree

6 files changed

+6
-10
lines changed

6 files changed

+6
-10
lines changed

ml_inspector/_metrics_curves.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ class as well as for the micro-average.
118118
"""
119119
curve_data = {}
120120
for i, c in enumerate(classes):
121-
if i == 0:
122-
print(y_true == c, y_prob[:, i])
123121
curve_data[c] = self.curve_function(y_true == c, y_prob[:, i])
124122
curve_data["Average"] = self.curve_function(
125123
label_binarize(y_true, classes=classes).ravel(), y_prob.ravel()

ml_inspector/gain_curves.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,11 @@ def calculate_gain_curve(y_true: np.ndarray, y_prob: np.ndarray) -> tuple:
2727
* the corresponding thresholds
2828
"""
2929
y_prob = pd.Series(np.array(y_prob)).sort_values(ascending=False)
30-
y_true = pd.Series(np.array(y_true).astype(int)).reindex_like(y_prob)
31-
print("y_prob", y_prob)
32-
print("y_true", y_true)
30+
y_true = pd.Series(np.array(y_true)).reindex_like(y_prob)
3331
recalls = y_true.cumsum() / y_true.sum()
3432
fractions = [i / len(y_true) for i in range(len(y_true))]
3533
thresholds = y_prob
36-
return np.array(fractions).round(3), np.array(recalls), np.array(thresholds)
34+
return np.array(fractions), np.array(recalls), np.array(thresholds)
3735

3836

3937
class GainCurves(MetricsCurves):

tests/fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def multiclass_predictions():
3131
[0.2, 0.6, 0.1, 0.1],
3232
[0.4, 0.1, 0.2, 0.3],
3333
[0.2, 0.8, 0.0, 0.0],
34-
[0.3, 0.4, 0.2, 0.1],
34+
[0.2, 0.4, 0.3, 0.1],
3535
[0.6, 0.1, 0.1, 0.2],
3636
]
3737
)

tests/test_gain_curves.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_plot_gain_curves_multi_class(self, multiclass_predictions):
3232
assert fig.data[3]["name"] == "Class 3 (Training): AUC=0.83"
3333
assert fig.data[4]["name"] == "Micro-average Gain curve (Training): AUC=0.84"
3434
print(fig.data[5])
35-
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.72"
35+
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.71"
3636
assert fig.data[-1]["name"] == "Random decision: AUC=0.50"
3737

3838
def test_plot_gain_curves_error_single_class(self, binary_predictions):

tests/test_precision_recall_curves.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_plot_pr_curves_multi_class(self, multiclass_predictions):
3333
assert fig.data[4]["name"] == (
3434
"Micro-average Precision-Recall curve (Training): AUC=0.90"
3535
)
36-
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.75"
36+
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.79"
3737
assert fig.data[-1]["name"] == "Random decision: AUC=0.25"
3838

3939
def test_plot_pr_curves_error_single_class(self, binary_predictions):

tests/test_roc_curves.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_plot_roc_curves_multi_class(self, multiclass_predictions):
3131
assert fig.data[2]["name"] == "Class 2 (Training): AUC=1.00"
3232
assert fig.data[3]["name"] == "Class 3 (Training): AUC=1.00"
3333
assert fig.data[4]["name"] == "Micro-average ROC curve (Training): AUC=0.96"
34-
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.81"
34+
assert fig.data[5]["name"] == "Class 0 (Test): AUC=0.88"
3535
assert fig.data[-1]["name"] == "Random decision: AUC=0.50"
3636

3737
def test_plot_roc_curves_error_single_class(self, binary_predictions):

0 commit comments

Comments
 (0)