Skip to content

Commit edd69ab

Browse files
committed
Merge branch 'feature/fix_qini_score' into dev
2 parents 1f5fdb9 + 20c6ccc commit edd69ab

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklift/metrics/metrics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,24 +176,25 @@ def qini_auc_score(y_true, uplift, treatment, negative_effect=True):
176176
"""
177177
# ToDO: Add Continuous Outcomes
178178
check_consistent_length(y_true, uplift, treatment)
179+
n_samples = len(y_true)
179180

180181
y_true, uplift, treatment = np.array(y_true), np.array(uplift), np.array(treatment)
181182

182183
if not isinstance(negative_effect, bool):
183184
raise TypeError(f'Negative_effects flag should be bool, got: {type(negative_effect)}')
184185

185186
ratio_random = (y_true[treatment == 1].sum() - len(y_true[treatment == 1]) *
186-
y_true[treatment == 0].sum() / len(y_true[treatment == 0])) / len(y_true)
187+
y_true[treatment == 0].sum() / len(y_true[treatment == 0]))
187188

188-
x_baseline, y_baseline = np.array([0, len(y_true)]), np.array([0, ratio_random])
189+
x_baseline, y_baseline = np.array([0, n_samples]), np.array([0, ratio_random])
189190
auc_score_baseline = auc(x_baseline, y_baseline)
190191

191192
if negative_effect:
192193
x_perfect, y_perfect = qini_curve(
193194
y_true, y_true * treatment - y_true * (1 - treatment), treatment
194195
)
195196
else:
196-
x_perfect, y_perfect = np.array([0, ratio_random, len(y_true)]), np.array([0, ratio_random, ratio_random])
197+
x_perfect, y_perfect = np.array([0, ratio_random, n_samples]), np.array([0, ratio_random, ratio_random])
197198

198199
qini_auc_score_perfect = auc(x_perfect, y_perfect) - auc_score_baseline
199200
qini_auc_score_actual = auc(*qini_curve(y_true, uplift, treatment)) - auc_score_baseline

0 commit comments

Comments
 (0)