Skip to content

Commit ff2613e

Browse files
committed
🔨 Add uplift plot_uplift_curve and plot_qini_curve; Remove plot_uplift_qini_curves
1 parent 2b9cdd2 commit ff2613e

File tree

3 files changed

+86
-66
lines changed

3 files changed

+86
-66
lines changed

sklift/metrics/metrics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def uplift_curve(y_true, uplift, treatment):
2323
2424
:func:`.perfect_uplift_curve`: Compute the perfect Uplift curve.
2525
26-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
26+
:func:`.plot_uplift_curve`: Plot Uplift curves from predictions.
2727
2828
:func:`.qini_curve`: Compute Qini curve.
2929
@@ -82,7 +82,7 @@ def perfect_uplift_curve(y_true, treatment):
8282
8383
:func:`.uplift_auc_score`: Compute normalized Area Under the Uplift curve from prediction scores.
8484
85-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
85+
:func:`.plot_uplift_curve`: Plot Uplift curves from predictions.
8686
"""
8787
check_consistent_length(y_true, treatment)
8888
y_true, treatment = np.array(y_true), np.array(treatment)
@@ -116,7 +116,7 @@ def uplift_auc_score(y_true, uplift, treatment):
116116
117117
:func:`.perfect_uplift_curve`: Compute the perfect (optimum) Uplift curve.
118118
119-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
119+
:func:`.plot_uplift_curve`: Plot Uplift curves from predictions.
120120
121121
:func:`.qini_auc_score`: Compute normalized Area Under the Qini Curve from prediction scores.
122122
"""
@@ -153,7 +153,7 @@ def qini_curve(y_true, uplift, treatment):
153153
154154
:func:`.perfect_qini_curve`: Compute the perfect Qini curve.
155155
156-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
156+
:func:`.plot_qini_curves`: Plot Qini curves from predictions..
157157
158158
:func:`.uplift_curve`: Compute Uplift curve.
159159
@@ -217,7 +217,7 @@ def perfect_qini_curve(y_true, treatment, negative_effect=True):
217217
218218
:func:`.qini_auc_score`: Compute the area under the Qini curve.
219219
220-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
220+
:func:`.plot_qini_curves`: Plot Qini curves from predictions..
221221
"""
222222
check_consistent_length(y_true, treatment)
223223
n_samples = len(y_true)
@@ -264,7 +264,7 @@ def qini_auc_score(y_true, uplift, treatment, negative_effect=True):
264264
265265
:func:`.perfect_qini_curve`: Compute the perfect (optimum) Qini curve.
266266
267-
:func:`.plot_uplift_qini_curves`: Plot Uplift and Qini curves.
267+
:func:`.plot_qini_curves`: Plot Qini curves from predictions..
268268
269269
:func:`.uplift_auc_score`: Compute normalized Area Under the Uplift curve from prediction scores.
270270

sklift/viz/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1-
from .base import plot_uplift_preds, plot_uplift_qini_curves, plot_uplift_by_percentile, plot_treatment_balance_curve
1+
from .base import (
2+
plot_uplift_curve, plot_qini_curve, plot_uplift_preds,
3+
plot_uplift_by_percentile, plot_treatment_balance_curve
4+
)
25

3-
__all__ = [plot_uplift_preds, plot_uplift_qini_curves, plot_uplift_by_percentile, plot_treatment_balance_curve]
6+
__all__ = [
7+
plot_uplift_curve, plot_qini_curve, plot_uplift_preds,
8+
plot_uplift_by_percentile, plot_treatment_balance_curve
9+
]

sklift/viz/base.py

Lines changed: 72 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import numpy as np
21
import matplotlib.pyplot as plt
2+
import numpy as np
33
from sklearn.utils.validation import check_consistent_length
4+
45
from ..metrics import (
5-
uplift_curve,
6-
auuc,
7-
qini_curve,
8-
auqc,
9-
response_rate_by_percentile,
10-
uplift_by_percentile,
6+
uplift_curve, perfect_uplift_curve, uplift_auc_score,
7+
qini_curve, perfect_qini_curve, qini_auc_score,
118
treatment_balance_curve
129
)
1310

@@ -18,11 +15,11 @@ def plot_uplift_preds(trmnt_preds, ctrl_preds, log=False, bins=100):
1815
Args:
1916
trmnt_preds (1d array-like): Predictions for all observations if they are treatment.
2017
ctrl_preds (1d array-like): Predictions for all observations if they are control.
21-
log (bool, default False): Logarithm of source samples.
18+
log (bool, default False): Logarithm of source samples. Default is False.
2219
bins (integer or sequence, default 100): Number of histogram bins to be used.
2320
If an integer is given, bins + 1 bin edges are calculated and returned.
2421
If bins is a sequence, gives bin edges, including left edge of first bin and right edge of last bin.
25-
In this case, bins is returned unmodified.
22+
In this case, bins is returned unmodified. Default is 100.
2623
2724
Returns:
2825
Object that stores computed values.
@@ -57,66 +54,83 @@ def plot_uplift_preds(trmnt_preds, ctrl_preds, log=False, bins=100):
5754
return axes
5855

5956

60-
def plot_uplift_qini_curves(y_true, uplift, treatment, random=True, perfect=False):
61-
"""Plot Uplift and Qini curves.
57+
def plot_uplift_curve(y_true, uplift, treatment, random=True, perfect=True):
58+
"""Plot Uplift curves from predictions.
6259
6360
Args:
6461
y_true (1d array-like): Ground truth (correct) labels.
6562
uplift (1d array-like): Predicted uplift, as returned by a model.
6663
treatment (1d array-like): Treatment labels.
67-
random (bool, default True): Draw a random curve.
68-
perfect (bool, default False): Draw a perfect curve.
64+
random (bool, default True): Draw a random curve. Default is True.
65+
perfect (bool, default False): Draw a perfect curve. Default is True.
6966
7067
Returns:
7168
Object that stores computed values.
7269
"""
7370
check_consistent_length(y_true, uplift, treatment)
7471
y_true, uplift, treatment = np.array(y_true), np.array(uplift), np.array(treatment)
7572

76-
x_up, y_up = uplift_curve(y_true, uplift, treatment)
77-
x_qi, y_qi = qini_curve(y_true, uplift, treatment)
73+
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 6))
7874

79-
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(14, 7))
80-
81-
axes[0].plot(x_up, y_up, label='Model', color='b')
82-
axes[1].plot(x_qi, y_qi, label='Model', color='b')
75+
x_actual, y_actual = uplift_curve(y_true, uplift, treatment)
76+
ax.plot(x_actual, y_actual, label='Model', color='blue')
8377

8478
if random:
85-
up_ratio_random = (y_true[treatment == 1].sum() / len(y_true[treatment == 1]) -
86-
y_true[treatment == 0].sum() / len(y_true[treatment == 0]))
87-
y_up_random = x_up * up_ratio_random
79+
x_baseline, y_baseline = x_actual, x_actual * y_actual[-1] / len(y_true)
80+
ax.plot(x_baseline, y_baseline, label='Random', color='black')
81+
ax.fill_between(x_actual, y_actual, y_baseline, alpha=0.2, color='b')
82+
83+
if perfect:
84+
x_perfect, y_perfect = perfect_uplift_curve(y_true, treatment)
85+
ax.plot(x_perfect, y_perfect, label='Perfect', color='Red')
8886

89-
qi_ratio_random = (y_true[treatment == 1].sum() - len(y_true[treatment == 1]) *
90-
y_true[treatment == 0].sum() / len(y_true[treatment == 0])) / len(y_true)
91-
y_qi_random = x_qi * qi_ratio_random
87+
ax.legend(loc='lower right')
88+
ax.set_title(f'Uplift curve\nuplift_auc_score={uplift_auc_score(y_true, uplift, treatment):.2f}')
89+
ax.set_xlabel('Number targeted')
90+
ax.set_ylabel('Gain: treatment - control')
9291

93-
axes[0].plot(x_up, y_up_random, label='Random', color='black')
94-
axes[0].fill_between(x_up, y_up, y_up_random, alpha=0.2, color='b')
95-
axes[1].plot(x_qi, y_qi_random, label='Random', color='black')
96-
axes[1].fill_between(x_qi, y_qi, y_qi_random, alpha=0.2, color='b')
92+
return ax
93+
94+
95+
def plot_qini_curve(y_true, uplift, treatment, random=True, perfect=True, negative_effect=True):
96+
"""Plot Qini curves from predictions.
97+
98+
Args:
99+
y_true (1d array-like): Ground truth (correct) labels.
100+
uplift (1d array-like): Predicted uplift, as returned by a model.
101+
treatment (1d array-like): Treatment labels.
102+
random (bool, default True): Draw a random curve. Default is True.
103+
perfect (bool, default False): Draw a perfect curve. Default is True.
104+
negative_effect (bool): If True, optimum Qini Curve contains the negative effects
105+
(negative uplift because of campaign). Otherwise, optimum Qini Curve will not
106+
contain the negative effects. Default is True.
107+
108+
Returns:
109+
Object that stores computed values.
110+
"""
111+
check_consistent_length(y_true, uplift, treatment)
112+
y_true, uplift, treatment = np.array(y_true), np.array(uplift), np.array(treatment)
113+
114+
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 6))
115+
116+
x_actual, y_actual = qini_curve(y_true, uplift, treatment)
117+
ax.plot(x_actual, y_actual, label='Model', color='blue')
118+
119+
if random:
120+
x_baseline, y_baseline = x_actual, x_actual * y_actual[-1] / len(y_true)
121+
ax.plot(x_baseline, y_baseline, label='Random', color='black')
122+
ax.fill_between(x_actual, y_actual, y_baseline, alpha=0.2, color='b')
97123

98124
if perfect:
99-
x_up_perfect, y_up_perfect = uplift_curve(
100-
y_true, y_true * treatment - y_true * (1 - treatment), treatment
101-
)
102-
x_qi_perfect, y_qi_perfect = qini_curve(
103-
y_true, y_true * treatment - y_true * (1 - treatment), treatment
104-
)
105-
106-
axes[0].plot(x_up_perfect, y_up_perfect, label='Perfect', color='red')
107-
axes[1].plot(x_qi_perfect, y_qi_perfect, label='Perfect', color='red')
108-
109-
axes[0].legend(loc='upper left')
110-
axes[0].set_title(f'Uplift curve: AUUC={auuc(y_true, uplift, treatment):.2f}')
111-
axes[0].set_xlabel('Number targeted')
112-
axes[0].set_ylabel('Relative gain: treatment - control')
113-
114-
axes[1].legend(loc='upper left')
115-
axes[1].set_title(f'Qini curve: AUQC={auqc(y_true, uplift, treatment):.2f}')
116-
axes[1].set_xlabel('Number targeted')
117-
axes[1].set_ylabel('Number of incremental outcome')
125+
x_perfect, y_perfect = perfect_qini_curve(y_true, treatment, negative_effect)
126+
ax.plot(x_perfect, y_perfect, label='Perfect', color='Red')
118127

119-
return axes
128+
ax.legend(loc='lower right')
129+
ax.set_title(f'Qini curve\nqini_auc_score={qini_auc_score(y_true, uplift, treatment, negative_effect):.2f}')
130+
ax.set_xlabel('Number targeted')
131+
ax.set_ylabel('Number of incremental outcome')
132+
133+
return ax
120134

121135

122136
def plot_uplift_by_percentile(y_true, uplift, treatment, strategy='overall', kind='line', bins=10):
@@ -250,19 +264,19 @@ def plot_treatment_balance_curve(uplift, treatment, random=True, winsize=0.1):
250264

251265
x_tb, y_tb = treatment_balance_curve(uplift, treatment, winsize=int(len(uplift)*winsize))
252266

253-
_, axes = plt.subplots(ncols=1, nrows=1, figsize=(14, 7))
267+
_, ax = plt.subplots(ncols=1, nrows=1, figsize=(14, 7))
254268

255-
axes.plot(x_tb, y_tb, label='Model', color='b')
269+
ax.plot(x_tb, y_tb, label='Model', color='b')
256270

257271
if random:
258272
y_tb_random = np.average(treatment) * np.ones_like(x_tb)
259273

260-
axes.plot(x_tb, y_tb_random, label='Random', color='black')
261-
axes.fill_between(x_tb, y_tb, y_tb_random, alpha=0.2, color='b')
274+
ax.plot(x_tb, y_tb_random, label='Random', color='black')
275+
ax.fill_between(x_tb, y_tb, y_tb_random, alpha=0.2, color='b')
262276

263-
axes.legend()
264-
axes.set_title('Treatment balance curve')
265-
axes.set_xlabel('Percentage targeted')
266-
axes.set_ylabel('Balance: treatment / (treatment + control)')
277+
ax.legend()
278+
ax.set_title('Treatment balance curve')
279+
ax.set_xlabel('Percentage targeted')
280+
ax.set_ylabel('Balance: treatment / (treatment + control)')
267281

268-
return axes
282+
return ax

0 commit comments

Comments
 (0)