1- import numpy as np
21import matplotlib .pyplot as plt
2+ import numpy as np
33from sklearn .utils .validation import check_consistent_length
4+
45from ..metrics import (
5- uplift_curve ,
6- auuc ,
7- qini_curve ,
8- auqc ,
9- response_rate_by_percentile ,
10- uplift_by_percentile ,
6+ uplift_curve , perfect_uplift_curve , uplift_auc_score ,
7+ qini_curve , perfect_qini_curve , qini_auc_score ,
118 treatment_balance_curve
129)
1310
@@ -18,11 +15,11 @@ def plot_uplift_preds(trmnt_preds, ctrl_preds, log=False, bins=100):
1815 Args:
1916 trmnt_preds (1d array-like): Predictions for all observations if they are treatment.
2017 ctrl_preds (1d array-like): Predictions for all observations if they are control.
21- log (bool, default False): Logarithm of source samples.
18+ log (bool, default False): Logarithm of source samples. Default is False.
2219 bins (integer or sequence, default 100): Number of histogram bins to be used.
2320 If an integer is given, bins + 1 bin edges are calculated and returned.
2421 If bins is a sequence, gives bin edges, including left edge of first bin and right edge of last bin.
25- In this case, bins is returned unmodified.
22+ In this case, bins is returned unmodified. Default is 100.
2623
2724 Returns:
2825 Object that stores computed values.
@@ -57,66 +54,83 @@ def plot_uplift_preds(trmnt_preds, ctrl_preds, log=False, bins=100):
5754 return axes
5855
5956
60- def plot_uplift_qini_curves (y_true , uplift , treatment , random = True , perfect = False ):
61- """Plot Uplift and Qini curves .
57+ def plot_uplift_curve (y_true , uplift , treatment , random = True , perfect = True ):
58+ """Plot Uplift curves from predictions .
6259
6360 Args:
6461 y_true (1d array-like): Ground truth (correct) labels.
6562 uplift (1d array-like): Predicted uplift, as returned by a model.
6663 treatment (1d array-like): Treatment labels.
67- random (bool, default True): Draw a random curve.
68- perfect (bool, default False): Draw a perfect curve.
64+ random (bool, default True): Draw a random curve. Default is True.
65+ perfect (bool, default False): Draw a perfect curve. Default is True.
6966
7067 Returns:
7168 Object that stores computed values.
7269 """
7370 check_consistent_length (y_true , uplift , treatment )
7471 y_true , uplift , treatment = np .array (y_true ), np .array (uplift ), np .array (treatment )
7572
76- x_up , y_up = uplift_curve (y_true , uplift , treatment )
77- x_qi , y_qi = qini_curve (y_true , uplift , treatment )
73+ fig , ax = plt .subplots (ncols = 1 , nrows = 1 , figsize = (8 , 6 ))
7874
79- fig , axes = plt .subplots (ncols = 2 , nrows = 1 , figsize = (14 , 7 ))
80-
81- axes [0 ].plot (x_up , y_up , label = 'Model' , color = 'b' )
82- axes [1 ].plot (x_qi , y_qi , label = 'Model' , color = 'b' )
75+ x_actual , y_actual = uplift_curve (y_true , uplift , treatment )
76+ ax .plot (x_actual , y_actual , label = 'Model' , color = 'blue' )
8377
8478 if random :
85- up_ratio_random = (y_true [treatment == 1 ].sum () / len (y_true [treatment == 1 ]) -
86- y_true [treatment == 0 ].sum () / len (y_true [treatment == 0 ]))
87- y_up_random = x_up * up_ratio_random
79+ x_baseline , y_baseline = x_actual , x_actual * y_actual [- 1 ] / len (y_true )
80+ ax .plot (x_baseline , y_baseline , label = 'Random' , color = 'black' )
81+ ax .fill_between (x_actual , y_actual , y_baseline , alpha = 0.2 , color = 'b' )
82+
83+ if perfect :
84+ x_perfect , y_perfect = perfect_uplift_curve (y_true , treatment )
85+ ax .plot (x_perfect , y_perfect , label = 'Perfect' , color = 'Red' )
8886
89- qi_ratio_random = (y_true [treatment == 1 ].sum () - len (y_true [treatment == 1 ]) *
90- y_true [treatment == 0 ].sum () / len (y_true [treatment == 0 ])) / len (y_true )
91- y_qi_random = x_qi * qi_ratio_random
87+ ax .legend (loc = 'lower right' )
88+ ax .set_title (f'Uplift curve\n uplift_auc_score={ uplift_auc_score (y_true , uplift , treatment ):.2f} ' )
89+ ax .set_xlabel ('Number targeted' )
90+ ax .set_ylabel ('Gain: treatment - control' )
9291
93- axes [0 ].plot (x_up , y_up_random , label = 'Random' , color = 'black' )
94- axes [0 ].fill_between (x_up , y_up , y_up_random , alpha = 0.2 , color = 'b' )
95- axes [1 ].plot (x_qi , y_qi_random , label = 'Random' , color = 'black' )
96- axes [1 ].fill_between (x_qi , y_qi , y_qi_random , alpha = 0.2 , color = 'b' )
92+ return ax
93+
94+
95+ def plot_qini_curve (y_true , uplift , treatment , random = True , perfect = True , negative_effect = True ):
96+ """Plot Qini curves from predictions.
97+
98+ Args:
99+ y_true (1d array-like): Ground truth (correct) labels.
100+ uplift (1d array-like): Predicted uplift, as returned by a model.
101+ treatment (1d array-like): Treatment labels.
102+ random (bool, default True): Draw a random curve. Default is True.
103+ perfect (bool, default False): Draw a perfect curve. Default is True.
104+ negative_effect (bool): If True, optimum Qini Curve contains the negative effects
105+ (negative uplift because of campaign). Otherwise, optimum Qini Curve will not
106+ contain the negative effects. Default is True.
107+
108+ Returns:
109+ Object that stores computed values.
110+ """
111+ check_consistent_length (y_true , uplift , treatment )
112+ y_true , uplift , treatment = np .array (y_true ), np .array (uplift ), np .array (treatment )
113+
114+ fig , ax = plt .subplots (ncols = 1 , nrows = 1 , figsize = (8 , 6 ))
115+
116+ x_actual , y_actual = qini_curve (y_true , uplift , treatment )
117+ ax .plot (x_actual , y_actual , label = 'Model' , color = 'blue' )
118+
119+ if random :
120+ x_baseline , y_baseline = x_actual , x_actual * y_actual [- 1 ] / len (y_true )
121+ ax .plot (x_baseline , y_baseline , label = 'Random' , color = 'black' )
122+ ax .fill_between (x_actual , y_actual , y_baseline , alpha = 0.2 , color = 'b' )
97123
98124 if perfect :
99- x_up_perfect , y_up_perfect = uplift_curve (
100- y_true , y_true * treatment - y_true * (1 - treatment ), treatment
101- )
102- x_qi_perfect , y_qi_perfect = qini_curve (
103- y_true , y_true * treatment - y_true * (1 - treatment ), treatment
104- )
105-
106- axes [0 ].plot (x_up_perfect , y_up_perfect , label = 'Perfect' , color = 'red' )
107- axes [1 ].plot (x_qi_perfect , y_qi_perfect , label = 'Perfect' , color = 'red' )
108-
109- axes [0 ].legend (loc = 'upper left' )
110- axes [0 ].set_title (f'Uplift curve: AUUC={ auuc (y_true , uplift , treatment ):.2f} ' )
111- axes [0 ].set_xlabel ('Number targeted' )
112- axes [0 ].set_ylabel ('Relative gain: treatment - control' )
113-
114- axes [1 ].legend (loc = 'upper left' )
115- axes [1 ].set_title (f'Qini curve: AUQC={ auqc (y_true , uplift , treatment ):.2f} ' )
116- axes [1 ].set_xlabel ('Number targeted' )
117- axes [1 ].set_ylabel ('Number of incremental outcome' )
125+ x_perfect , y_perfect = perfect_qini_curve (y_true , treatment , negative_effect )
126+ ax .plot (x_perfect , y_perfect , label = 'Perfect' , color = 'Red' )
118127
119- return axes
128+ ax .legend (loc = 'lower right' )
129+ ax .set_title (f'Qini curve\n qini_auc_score={ qini_auc_score (y_true , uplift , treatment , negative_effect ):.2f} ' )
130+ ax .set_xlabel ('Number targeted' )
131+ ax .set_ylabel ('Number of incremental outcome' )
132+
133+ return ax
120134
121135
122136def plot_uplift_by_percentile (y_true , uplift , treatment , strategy = 'overall' , kind = 'line' , bins = 10 ):
@@ -250,19 +264,19 @@ def plot_treatment_balance_curve(uplift, treatment, random=True, winsize=0.1):
250264
251265 x_tb , y_tb = treatment_balance_curve (uplift , treatment , winsize = int (len (uplift )* winsize ))
252266
253- _ , axes = plt .subplots (ncols = 1 , nrows = 1 , figsize = (14 , 7 ))
267+ _ , ax = plt .subplots (ncols = 1 , nrows = 1 , figsize = (14 , 7 ))
254268
255- axes .plot (x_tb , y_tb , label = 'Model' , color = 'b' )
269+ ax .plot (x_tb , y_tb , label = 'Model' , color = 'b' )
256270
257271 if random :
258272 y_tb_random = np .average (treatment ) * np .ones_like (x_tb )
259273
260- axes .plot (x_tb , y_tb_random , label = 'Random' , color = 'black' )
261- axes .fill_between (x_tb , y_tb , y_tb_random , alpha = 0.2 , color = 'b' )
274+ ax .plot (x_tb , y_tb_random , label = 'Random' , color = 'black' )
275+ ax .fill_between (x_tb , y_tb , y_tb_random , alpha = 0.2 , color = 'b' )
262276
263- axes .legend ()
264- axes .set_title ('Treatment balance curve' )
265- axes .set_xlabel ('Percentage targeted' )
266- axes .set_ylabel ('Balance: treatment / (treatment + control)' )
277+ ax .legend ()
278+ ax .set_title ('Treatment balance curve' )
279+ ax .set_xlabel ('Percentage targeted' )
280+ ax .set_ylabel ('Balance: treatment / (treatment + control)' )
267281
268- return axes
282+ return ax
0 commit comments