Skip to content

Commit 030bfa1

Browse files
Fix test_metrics (#148)
* modified: sklift/tests/test_metrics.py * deleted: sklift/tests/test_plot_qini_curve.py deleted: sklift/tests/test_plot_uplift_curve.py new file: sklift/tests/test_viz.py
1 parent b73d82c commit 030bfa1

File tree

4 files changed

+278
-248
lines changed

4 files changed

+278
-248
lines changed

sklift/tests/test_metrics.py

Lines changed: 109 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,33 @@ def make_predictions(binary):
4141
def test_uplift_curve(binary, test_x_actual, test_y_actual):
4242
y_true, uplift, treatment = make_predictions(binary)
4343

44-
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
44+
if binary == False:
45+
with pytest.raises(Exception):
46+
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
47+
else:
48+
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
4549

46-
assert_array_almost_equal(x_actual, test_x_actual)
47-
assert_array_almost_equal(y_actual, test_y_actual)
48-
assert x_actual.shape == y_actual.shape
50+
assert_array_almost_equal(x_actual, test_x_actual)
51+
assert_array_almost_equal(y_actual, test_y_actual)
52+
assert x_actual.shape == y_actual.shape
4953

5054

5155
def test_uplift_curve_hard():
52-
y_true, uplift, treatment = make_predictions(binary=True)
53-
y_true = np.zeros(y_true.shape)
56+
with pytest.raises(Exception):
57+
y_true, uplift, treatment = make_predictions(binary=True)
58+
y_true = np.zeros(y_true.shape)
5459

55-
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
60+
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
5661

57-
assert_array_almost_equal(x_actual, np.array([0, 3]))
58-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
62+
assert_array_almost_equal(x_actual, np.array([0, 3]))
63+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
5964

60-
y_true = np.ones(y_true.shape)
65+
y_true = np.ones(y_true.shape)
6166

62-
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
67+
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
6368

64-
assert_array_almost_equal(x_actual, np.array([0, 3]))
65-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
69+
assert_array_almost_equal(x_actual, np.array([0, 3]))
70+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
6671

6772

6873
@pytest.mark.parametrize(
@@ -74,42 +79,35 @@ def test_uplift_curve_hard():
7479
)
7580
def test_perfect_uplift_curve(binary, test_x_actual, test_y_actual):
7681
y_true, uplift, treatment = make_predictions(binary)
77-
78-
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
79-
80-
assert_array_almost_equal(x_actual, test_x_actual)
81-
assert_array_almost_equal(y_actual, test_y_actual)
82-
assert x_actual.shape == y_actual.shape
82+
if binary == False:
83+
with pytest.raises(Exception):
84+
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
85+
else:
86+
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
87+
assert_array_almost_equal(x_actual, test_x_actual)
88+
assert_array_almost_equal(y_actual, test_y_actual)
89+
assert x_actual.shape == y_actual.shape
8390

8491

8592
def test_perfect_uplift_curve_hard():
86-
y_true, uplift, treatment = make_predictions(binary=True)
87-
y_true = np.zeros(y_true.shape)
93+
with pytest.raises(Exception):
94+
y_true, uplift, treatment = make_predictions(binary=True)
95+
y_true = np.zeros(y_true.shape)
8896

89-
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
97+
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
9098

91-
assert_array_almost_equal(x_actual, np.array([0, 1, 3]))
92-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
99+
assert_array_almost_equal(x_actual, np.array([0, 1, 3]))
100+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
93101

94-
y_true = np.ones(y_true.shape)
102+
y_true = np.ones(y_true.shape)
95103

96-
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
104+
x_actual, y_actual = perfect_uplift_curve(y_true, treatment)
97105

98-
assert_array_almost_equal(x_actual, np.array([0, 2, 3]))
99-
assert_array_almost_equal(y_actual, np.array([0.0, 2.0, 0.0]))
106+
assert_array_almost_equal(x_actual, np.array([0, 2, 3]))
107+
assert_array_almost_equal(y_actual, np.array([0.0, 2.0, 0.0]))
100108

101109

102110
def test_uplift_auc_score():
103-
y_true = [1, 1]
104-
uplift = [0.1, 0.3]
105-
treatment = [0, 1]
106-
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), 1.)
107-
108-
y_true = [1, 1]
109-
uplift = [0.1, 0.3]
110-
treatment = [1, 0]
111-
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), -1.)
112-
113111
y_true = [0, 1]
114112
uplift = [0.1, 0.3]
115113
treatment = [1, 0]
@@ -120,15 +118,26 @@ def test_uplift_auc_score():
120118
treatment = [0, 1]
121119
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), 1.)
122120

123-
y_true = [0, 1, 2]
124-
uplift = [0.1, 0.3, 0.9]
125-
treatment = [0, 1, 0]
126-
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), -1.333333)
121+
with pytest.raises(Exception):
122+
y_true = [1, 1]
123+
uplift = [0.1, 0.3]
124+
treatment = [0, 1]
125+
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), 1.)
127126

128-
y_true = [0, 1, 2]
129-
uplift = [0.1, 0.3, 0.9]
130-
treatment = [1, 0, 1]
131-
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), 1.333333)
127+
y_true = [1, 1]
128+
uplift = [0.1, 0.3]
129+
treatment = [1, 0]
130+
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), -1.)
131+
132+
y_true = [0, 1, 2]
133+
uplift = [0.1, 0.3, 0.9]
134+
treatment = [0, 1, 0]
135+
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), -1.333333)
136+
137+
y_true = [0, 1, 2]
138+
uplift = [0.1, 0.3, 0.9]
139+
treatment = [1, 0, 1]
140+
assert_array_almost_equal(uplift_auc_score(y_true, uplift, treatment), 1.333333)
132141

133142

134143
@pytest.mark.parametrize(
@@ -141,37 +150,39 @@ def test_uplift_auc_score():
141150
def test_qini_curve(binary, test_x_actual, test_y_actual):
142151
y_true, uplift, treatment = make_predictions(binary)
143152

144-
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
145-
146-
assert_array_almost_equal(x_actual, test_x_actual)
147-
assert_array_almost_equal(y_actual, test_y_actual)
148-
assert x_actual.shape == y_actual.shape
153+
if binary == False:
154+
with pytest.raises(Exception):
155+
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
156+
else:
157+
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
158+
assert_array_almost_equal(x_actual, test_x_actual)
159+
assert_array_almost_equal(y_actual, test_y_actual)
160+
assert x_actual.shape == y_actual.shape
149161

150162

151163
def test_qini_curve_hard():
152-
y_true, uplift, treatment = make_predictions(binary=True)
153-
y_true = np.zeros(y_true.shape)
164+
with pytest.raises(Exception):
165+
y_true, uplift, treatment = make_predictions(binary=True)
166+
y_true = np.zeros(y_true.shape)
154167

155-
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
168+
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
156169

157-
assert_array_almost_equal(x_actual, np.array([0, 3]))
158-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
170+
assert_array_almost_equal(x_actual, np.array([0, 3]))
171+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
159172

160-
y_true = np.ones(y_true.shape)
173+
y_true = np.ones(y_true.shape)
161174

162-
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
175+
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
163176

164-
assert_array_almost_equal(x_actual, np.array([0, 3]))
165-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
177+
assert_array_almost_equal(x_actual, np.array([0, 3]))
178+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
166179

167180

168181
@pytest.mark.parametrize(
169182
"binary, negative_effect, test_x_actual, test_y_actual",
170183
[
171184
(True, True, np.array([0, 1, 3]), np.array([0., 1., 1.])),
172-
(False, True, np.array([0, 1, 2, 3]), np.array([0., 2., 3., 3.])),
173185
(True, False, np.array([0., 1., 3.]), np.array([0., 1., 1.])),
174-
(False, False, np.array([0., 3., 3.]), np.array([0., 3., 3.]))
175186
]
176187
)
177188
def test_perfect_qini_curve(binary, negative_effect, test_x_actual, test_y_actual):
@@ -185,43 +196,34 @@ def test_perfect_qini_curve(binary, negative_effect, test_x_actual, test_y_actua
185196

186197

187198
def test_perfect_qini_curve_hard():
188-
y_true, uplift, treatment = make_predictions(binary=True)
189-
y_true = np.zeros(y_true.shape)
199+
with pytest.raises(Exception):
200+
y_true, uplift, treatment = make_predictions(binary=True)
201+
y_true = np.zeros(y_true.shape)
190202

191-
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=True)
203+
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=True)
192204

193-
assert_array_almost_equal(x_actual, np.array([0, 3]))
194-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
205+
assert_array_almost_equal(x_actual, np.array([0, 3]))
206+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0]))
195207

196-
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=False)
208+
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=False)
197209

198-
assert_array_almost_equal(x_actual, np.array([0., 0., 3.]))
199-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
210+
assert_array_almost_equal(x_actual, np.array([0., 0., 3.]))
211+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
200212

201-
y_true = np.ones(y_true.shape)
213+
y_true = np.ones(y_true.shape)
202214

203-
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=True)
215+
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=True)
204216

205-
assert_array_almost_equal(x_actual, np.array([0, 2, 3]))
206-
assert_array_almost_equal(y_actual, np.array([0.0, 2.0, 0.0]))
217+
assert_array_almost_equal(x_actual, np.array([0, 2, 3]))
218+
assert_array_almost_equal(y_actual, np.array([0.0, 2.0, 0.0]))
207219

208-
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=False)
220+
x_actual, y_actual = perfect_qini_curve(y_true, treatment, negative_effect=False)
209221

210-
assert_array_almost_equal(x_actual, np.array([0., 0., 3.]))
211-
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
222+
assert_array_almost_equal(x_actual, np.array([0., 0., 3.]))
223+
assert_array_almost_equal(y_actual, np.array([0.0, 0.0, 0.0]))
212224

213225

214226
def test_qini_auc_score():
215-
y_true = [1, 1]
216-
uplift = [0.1, 0.3]
217-
treatment = [0, 1]
218-
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 1.)
219-
220-
y_true = [1, 1]
221-
uplift = [0.1, 0.3]
222-
treatment = [1, 0]
223-
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 0.)
224-
225227
y_true = [0, 1]
226228
uplift = [0.1, 0.3]
227229
treatment = [1, 0]
@@ -232,15 +234,26 @@ def test_qini_auc_score():
232234
treatment = [0, 1]
233235
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 1.)
234236

235-
y_true = [0, 1, 2]
236-
uplift = [0.1, 0.3, 0.9]
237-
treatment = [0, 1, 0]
238-
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), -0.5)
237+
with pytest.raises(Exception):
238+
y_true = [1, 1]
239+
uplift = [0.1, 0.3]
240+
treatment = [0, 1]
241+
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 1.)
242+
243+
y_true = [1, 1]
244+
uplift = [0.1, 0.3]
245+
treatment = [1, 0]
246+
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 0.)
247+
248+
y_true = [0, 1, 2]
249+
uplift = [0.1, 0.3, 0.9]
250+
treatment = [0, 1, 0]
251+
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), -0.5)
239252

240-
y_true = [0, 1, 2]
241-
uplift = [0.1, 0.3, 0.9]
242-
treatment = [1, 0, 1]
243-
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 0.75)
253+
y_true = [0, 1, 2]
254+
uplift = [0.1, 0.3, 0.9]
255+
treatment = [1, 0, 1]
256+
assert_array_almost_equal(qini_auc_score(y_true, uplift, treatment), 0.75)
244257

245258

246259
def test_uplift_at_k():
@@ -301,4 +314,4 @@ def test_treatment_balance_curve():
301314

302315
idx, balance = treatment_balance_curve(uplift, treatment, winsize=2)
303316
assert_array_almost_equal(idx, np.array([1., 100.]))
304-
assert_array_almost_equal(balance, np.array([1., 0.5]))
317+
assert_array_almost_equal(balance, np.array([1., 0.5]))

sklift/tests/test_plot_qini_curve.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

0 commit comments

Comments
 (0)