Skip to content

Commit 336a244

Browse files
committed
Merge branch 'dev' of https://github.com/maks-sh/scikit-uplift into dev
2 parents 5d0ef61 + 27173ff commit 336a244

File tree

2 files changed

+112
-67
lines changed

2 files changed

+112
-67
lines changed

sklift/metrics/metrics.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -286,83 +286,90 @@ def uplift_at_k(y_true, uplift, treatment, strategy, k=0.3):
286286

287287

288288
def response_rate_by_percentile(y_true, uplift, treatment, group, strategy, bins=10):
289-
"""Compute response rate (target mean in the control or treatment group) at each percentile.
290-
289+
"""Compute response rate and its variance at each percentile.
290+
291+
Response rate is a target mean in the group.
292+
291293
Args:
292294
y_true (1d array-like): Correct (true) target values.
293295
uplift (1d array-like): Predicted uplift, as returned by a model.
294296
treatment (1d array-like): Treatment labels.
295297
group (string, ['treatment', 'control']): Group type for computing response rate: treatment or control.
298+
296299
* ``'treatment'``:
297-
Values equal 1 in the treatment column.
300+
Values equal 1 in the treatment column.
301+
298302
* ``'control'``:
299-
Values equal 0 in the treatment column.
300-
strategy (string, ['overall', 'by_group']): Determines the calculating strategy.
303+
Values equal 0 in the treatment column.
304+
305+
strategy (string, ['overall', 'by_group']): Determines the calculating strategy.
306+
301307
* ``'overall'``:
302308
The first step is taking the first k observations of all test data ordered by uplift prediction
303309
(overall both groups - control and treatment) and conversions in treatment and control groups
304310
calculated only on them. Then the difference between these conversions is calculated.
311+
305312
* ``'by_group'``:
306313
Separately calculates conversions in top k observations in each group (control and treatment)
307-
sorted by uplift predictions. Then the difference between these conversions is calculated
308-
bins (int): Determines the number of bins (and relative percentile) in the test data.
309-
314+
sorted by uplift predictions. Then the difference between these conversions is calculated.
315+
316+
bins (int): Determines the number of bins (and the relative percentile) in the test data. Default is 10.
317+
310318
Returns:
311319
array: Response rate at each percentile for control or treatment group
312-
array: Variance of the response rate at each percentile
320+
array: Variance of the response rate at each percentile
313321
"""
314-
322+
315323
group_types = ['treatment', 'control']
316324
strategy_methods = ['overall', 'by_group']
317-
325+
318326
n_samples = len(y_true)
319327
check_consistent_length(y_true, uplift, treatment)
320-
328+
321329
if group not in group_types:
322330
raise ValueError(f'Response rate supports only group types in {group_types},'
323-
f' got {group}.')
331+
f' got {group}.')
324332

325333
if strategy not in strategy_methods:
326334
raise ValueError(f'Response rate supports only calculating methods in {strategy_methods},'
327335
f' got {strategy}.')
328-
336+
329337
if not isinstance(bins, int) or bins <= 0:
330-
raise ValueError(f'Bins should be positive integer.'
331-
f' Invalid value bins: {bins}')
332-
338+
raise ValueError(f'Bins should be positive integer. Invalid value bins: {bins}')
339+
333340
if bins >= n_samples:
334341
raise ValueError(f'Number of bins = {bins} should be smaller than the length of y_true {n_samples}')
335-
342+
336343
if bins == 1:
337344
warnings.warn(f'You will get the only one bin of {n_samples} samples'
338345
f' which is the length of y_true.'
339346
f'\nPlease consider using uplift_at_k function instead',
340347
UserWarning)
341-
348+
342349
y_true, uplift, treatment = np.array(y_true), np.array(uplift), np.array(treatment)
343350
order = np.argsort(uplift, kind='mergesort')[::-1]
344-
351+
345352
if group == 'treatment':
346353
trmnt_flag = 1
347354
else: # group == 'control'
348355
trmnt_flag = 0
349-
356+
350357
if strategy == 'overall':
351358
y_true_bin = np.array_split(y_true[order], bins)
352359
trmnt_bin = np.array_split(treatment[order], bins)
353-
360+
354361
group_size = np.array([len(y[trmnt == trmnt_flag]) for y, trmnt in zip(y_true_bin, trmnt_bin)])
355362
response_rate = np.array([np.mean(y[trmnt == trmnt_flag]) for y, trmnt in zip(y_true_bin, trmnt_bin)])
356363

357364
else: # strategy == 'by_group'
358365
y_bin = np.array_split(y_true[order][treatment[order] == trmnt_flag], bins)
359-
366+
360367
group_size = np.array([len(y) for y in y_bin])
361368
response_rate = np.array([np.mean(y) for y in y_bin])
362369

363370
variance = np.multiply(response_rate, np.divide((1 - response_rate), group_size))
364-
365-
return response_rate, variance
371+
372+
return response_rate, variance
366373

367374

368375
def treatment_balance_curve(uplift, treatment, winsize):

sklift/viz/base.py

Lines changed: 80 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import matplotlib.pyplot as plt
21
import numpy as np
3-
from sklearn.utils.validation import check_consistent_length
42
import warnings
3+
import matplotlib.pyplot as plt
4+
from sklearn.utils.validation import check_consistent_length
55
from ..metrics import uplift_curve, auuc, qini_curve, auqc, response_rate_by_percentile, treatment_balance_curve
66

77

@@ -20,7 +20,7 @@ def plot_uplift_preds(trmnt_preds, ctrl_preds, log=False, bins=100):
2020
Returns:
2121
Object that stores computed values.
2222
"""
23-
# ToDo: Add k as parameter: vertical line on plots
23+
# TODO: Add k as parameter: vertical line on plots
2424
check_consistent_length(trmnt_preds, ctrl_preds)
2525

2626
if not isinstance(bins, int) or bins <= 0:
@@ -112,78 +112,116 @@ def plot_uplift_qini_curves(y_true, uplift, treatment, random=True, perfect=Fals
112112
return axes
113113

114114

115-
def plot_uplift_by_percentile(y_true, uplift, treatment, strategy, bins=10):
116-
"""Plot Uplift score at each percentile,
117-
Treatment response rate (target mean in the treatment group)
118-
and Control response rate (target mean in the control group) at each percentile.
119-
115+
def plot_uplift_by_percentile(y_true, uplift, treatment, strategy, kind='line', bins=10):
116+
"""Plot uplift score, treatment response rate and control response rate at each percentile.
117+
118+
Treatment response rate is a target mean in the treatment group.
119+
Control response rate is a target mean in the control group.
120+
Uplift score is a difference between treatment response rate and control response rate.
121+
120122
Args:
121123
y_true (1d array-like): Correct (true) target values.
122124
uplift (1d array-like): Predicted uplift, as returned by a model.
123125
treatment (1d array-like): Treatment labels.
124-
strategy (string, ['overall', 'by_group']): Determines the calculating strategy. Defaults to 'first'.
126+
strategy (string, ['overall', 'by_group']): Determines the calculating strategy.
127+
125128
* ``'overall'``:
126129
The first step is taking the first k observations of all test data ordered by uplift prediction
127130
(overall both groups - control and treatment) and conversions in treatment and control groups
128131
calculated only on them. Then the difference between these conversions is calculated.
132+
129133
* ``'by_group'``:
130134
Separately calculates conversions in top k observations in each group (control and treatment)
131-
sorted by uplift predictions. Then the difference between these conversions is calculated
132-
bins (int): Determines the number of bins (and relative percentile) in the test data.
133-
135+
sorted by uplift predictions. Then the difference between these conversions is calculated.
136+
137+
kind (string, ['line', 'bar']): The type of plot to draw. Default is 'line'.
138+
139+
* ``'line'``:
140+
Generates a line plot.
141+
142+
* ``'bar'``:
143+
Generates a traditional bar-style plot.
144+
145+
bins (int): Determines the number of bins (and the relative percentile) in the test data. Default is 10.
146+
134147
Returns:
135148
Object that stores computed values.
136149
"""
137-
150+
138151
strategy_methods = ['overall', 'by_group']
139-
152+
kind_methods = ['line', 'bar']
153+
140154
n_samples = len(y_true)
141155
check_consistent_length(y_true, uplift, treatment)
142-
156+
143157
if strategy not in strategy_methods:
144158
raise ValueError(f'Response rate supports only calculating methods in {strategy_methods},'
145159
f' got {strategy}.')
146-
160+
161+
if kind not in kind_methods:
162+
raise ValueError(f'Function supports only types of plots in {kind_methods},'
163+
f' got {kind}.')
164+
147165
if not isinstance(bins, int) or bins <= 0:
148166
raise ValueError(f'Bins should be positive integer. Invalid value bins: {bins}')
149167

150168
if bins >= n_samples:
151169
raise ValueError(f'Number of bins = {bins} should be smaller than the length of y_true {n_samples}')
152-
153-
if bins == 1:
154-
warnings.warn(f'You will get the only one bin of {n_samples} samples'
155-
f' which is the length of y_true.'
156-
f'\nPlease consider using uplift_at_k function instead',
157-
UserWarning)
158-
170+
159171
rspns_rate_trmnt, var_trmnt = response_rate_by_percentile(y_true, uplift,
160172
treatment, group='treatment',
161173
strategy=strategy, bins=bins)
162-
174+
163175
rspns_rate_ctrl, var_ctrl = response_rate_by_percentile(y_true, uplift,
164176
treatment, group='control',
165177
strategy=strategy, bins=bins)
166178

167179
uplift_score, uplift_variance = np.subtract(rspns_rate_trmnt, rspns_rate_ctrl), np.add(var_trmnt, var_ctrl)
168-
180+
169181
percentiles = [p * 100 / bins for p in range(1, bins + 1)]
170-
171-
_, axes = plt.subplots(ncols=1, nrows=1, figsize=(8, 6))
172-
173-
axes.errorbar(percentiles, uplift_score, yerr=np.sqrt(uplift_variance),
174-
linewidth=2, color='red', label='uplift')
175-
axes.errorbar(percentiles, rspns_rate_trmnt, yerr=np.sqrt(var_trmnt),
176-
linewidth=2, color='forestgreen', label='treatment\nresponse rate')
177-
axes.errorbar(percentiles, rspns_rate_ctrl, yerr=np.sqrt(var_ctrl),
178-
linewidth=2, color='orange', label='control\nresponse rate')
179-
axes.fill_between(percentiles, rspns_rate_ctrl, rspns_rate_trmnt, alpha=0.1, color='red')
180-
181-
axes.set_xticks(percentiles)
182-
axes.legend(loc='upper right')
183-
axes.set_title('Uplift by percentile')
184-
axes.set_xlabel('Percentile')
185-
axes.set_ylabel('Uplift = treatment response rate - control response rate')
186-
182+
183+
if kind == 'line':
184+
_, axes = plt.subplots(ncols=1, nrows=1, figsize=(8, 6))
185+
axes.errorbar(percentiles, uplift_score, yerr=np.sqrt(uplift_variance),
186+
linewidth=2, color='red', label='uplift')
187+
axes.errorbar(percentiles, rspns_rate_trmnt, yerr=np.sqrt(var_trmnt),
188+
linewidth=2, color='forestgreen', label='treatment\nresponse rate')
189+
axes.errorbar(percentiles, rspns_rate_ctrl, yerr=np.sqrt(var_ctrl),
190+
linewidth=2, color='orange', label='control\nresponse rate')
191+
axes.fill_between(percentiles, rspns_rate_ctrl, rspns_rate_trmnt, alpha=0.1, color='red')
192+
193+
if np.amin(uplift_score) < 0:
194+
axes.axhline(y=0, color='black', linewidth=1)
195+
axes.set_xticks(percentiles)
196+
axes.legend(loc='upper right')
197+
axes.set_title('Uplift by percentile')
198+
axes.set_xlabel('Percentile')
199+
axes.set_ylabel('Uplift = treatment response rate - control response rate')
200+
201+
else: # kind == 'bar'
202+
delta = percentiles[0]
203+
fig, axes = plt.subplots(ncols=1, nrows=2, figsize=(8, 6), sharex=True, sharey=True)
204+
fig.text(0.04, 0.5, 'Uplift = treatment response rate - control response rate',
205+
va='center', ha='center', rotation='vertical')
206+
207+
axes[0].bar(np.array(percentiles), uplift_score, delta / 1.5,
208+
yerr=np.sqrt(uplift_variance), color='red', label='uplift')
209+
axes[1].bar(np.array(percentiles) - delta / 6, rspns_rate_trmnt, delta / 3,
210+
yerr=np.sqrt(var_trmnt), color='forestgreen', label='treatment\nresponse rate')
211+
axes[1].bar(np.array(percentiles) + delta / 6, rspns_rate_ctrl, delta / 3,
212+
yerr=np.sqrt(var_ctrl), color='orange', label='control\nresponse rate')
213+
214+
axes[0].legend(loc='upper right')
215+
axes[0].tick_params(axis='x', bottom=False)
216+
axes[0].axhline(y=0, color='black', linewidth=1)
217+
axes[0].set_title('Uplift by percentile')
218+
219+
axes[1].set_xticks(percentiles)
220+
axes[1].legend(loc='upper right')
221+
axes[1].axhline(y=0, color='black', linewidth=1)
222+
axes[1].set_xlabel('Percentile')
223+
axes[1].set_title('Response rate by percentile')
224+
187225
return axes
188226

189227

0 commit comments

Comments
 (0)