
Commit 27d3101
CATE validation - uplift uniform confidence bands (#840)
Add support for multiplier bootstrap uniform confidence band error bars for uplift curves
1 parent ed4fe33 commit 27d3101
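
As a quick orientation, the sketch below shows how the new error-bar options might be used end to end. It is not part of the commit: the synthetic data and the LinearDML estimator are illustrative assumptions; only the DRTester calls (fit_nuisance, evaluate_uplift, plot_uplift with err_type in [None, 'ucb1', 'ucb2']) mirror the API exercised in the tests changed below.

# Minimal usage sketch (illustrative only): the data-generating process and the
# LinearDML choice are assumptions, not part of this commit; the DRTester calls
# follow the tests in this PR.
import numpy as np
from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
from econml.dml import LinearDML
from econml.validate import DRTester

rng = np.random.default_rng(0)
n_train, n_val = 2000, 1000
X = rng.normal(size=(n_train + n_val, 3))
D = rng.binomial(1, 0.5, size=n_train + n_val)        # binary treatment
Y = X[:, 0] * D + rng.normal(size=n_train + n_val)    # effect varies with X[:, 0]
Xtrain, Xval = X[:n_train], X[n_train:]
Dtrain, Dval = D[:n_train], D[n_train:]
Ytrain, Yval = Y[:n_train], Y[n_train:]

# Any fitted CATE estimator can be validated; LinearDML is used here for brevity.
cate = LinearDML(
    model_y=GradientBoostingRegressor(),
    model_t=RandomForestClassifier(),
    discrete_treatment=True,
).fit(Y=Ytrain, T=Dtrain, X=Xtrain)

# Fit nuisance models, compute the Qini uplift curve on the validation sample,
# and plot it; per this commit, err_type accepts None, 'ucb1', or 'ucb2', where
# the 'ucb*' options draw multiplier-bootstrap uniform confidence bands.
tester = DRTester(
    model_regression=GradientBoostingRegressor(),
    model_propensity=RandomForestClassifier(),
    cate=cate,
).fit_nuisance(Xval, Dval, Yval, Xtrain, Dtrain, Ytrain)
qini_res = tester.evaluate_uplift(Xval, Xtrain, metric='qini')
fig = qini_res.plot_uplift(tmt=1, err_type='ucb2')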

7 files changed: +381 -149 lines changed

doc/reference.rst

Lines changed: 14 additions & 0 deletions
@@ -147,6 +147,20 @@ CATE Interpreters
     econml.cate_interpreter.SingleTreeCateInterpreter
     econml.cate_interpreter.SingleTreePolicyInterpreter
 
+.. _validation_api:
+
+CATE Validation
+---------------
+
+.. autosummary::
+    :toctree: _autosummary
+
+    econml.validate.DRTester
+    econml.validate.BLPEvaluationResults
+    econml.validate.CalibrationEvaluationResults
+    econml.validate.UpliftEvaluationResults
+    econml.validate.EvaluationResults
+
 .. _scorers_api:
 
 CATE Scorers

econml/tests/test_drtester.py

Lines changed: 27 additions & 21 deletions
@@ -5,7 +5,7 @@
 import scipy.stats as st
 from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor
 
-from econml.validate.drtester import DRtester
+from econml.validate.drtester import DRTester
 from econml.dml import DML
 
 
@@ -70,7 +70,7 @@ def test_multi(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -123,7 +123,7 @@ def test_binary(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -148,8 +148,8 @@ def test_binary(self):
                 self.assertRaises(ValueError, res.plot_toc, k)
             else:  # real treatment, k = 1
                 self.assertTrue(res.plot_cal(k) is not None)
-                self.assertTrue(res.plot_qini(k) is not None)
-                self.assertTrue(res.plot_toc(k) is not None)
+                self.assertTrue(res.plot_qini(k, 'ucb2') is not None)
+                self.assertTrue(res.plot_toc(k, 'ucb1') is not None)
 
         self.assertLess(res_df.blp_pval.values[0], 0.05)  # heterogeneity
         self.assertGreater(res_df.cal_r_squared.values[0], 0)  # good R2
@@ -171,7 +171,7 @@ def test_nuisance_val_fit(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -193,8 +193,8 @@ def test_nuisance_val_fit(self):
         for kwargs in [{}, {'Xval': Xval}]:
             with self.assertRaises(Exception) as exc:
                 my_dr_tester.evaluate_cal(kwargs)
-            self.assertTrue(
-                str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+            self.assertEqual(
+                str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
             )
 
     def test_exceptions(self):
@@ -212,7 +212,7 @@ def test_exceptions(self):
         ).fit(Y=Ytrain, T=Dtrain, X=Xtrain)
 
         # test the DR outcome difference
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
             model_regression=reg_y,
             model_propensity=reg_t,
             cate=cate
@@ -223,11 +223,11 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_cal':
-                self.assertTrue(
-                    str(exc.exception) == "Must fit nuisance models on training sample data to use calibration test"
+                self.assertEqual(
+                    str(exc.exception), "Must fit nuisance models on training sample data to use calibration test"
                 )
             else:
-                self.assertTrue(str(exc.exception) == "Must fit nuisances before evaluating")
+                self.assertEqual(str(exc.exception), "Must fit nuisances before evaluating")
 
         my_dr_tester = my_dr_tester.fit_nuisance(
             Xval, Dval, Yval, Xtrain, Dtrain, Ytrain
@@ -242,12 +242,12 @@ def test_exceptions(self):
             with self.assertRaises(Exception) as exc:
                 func()
             if func.__name__ == 'evaluate_blp':
-                self.assertTrue(
-                    str(exc.exception) == "CATE predictions not yet calculated - must provide Xval"
+                self.assertEqual(
+                    str(exc.exception), "CATE predictions not yet calculated - must provide Xval"
                 )
             else:
-                self.assertTrue(str(exc.exception) ==
-                                "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+                self.assertEqual(str(exc.exception),
+                                 "CATE predictions not yet calculated - must provide both Xval, Xtrain")
 
         for func in [
             my_dr_tester.evaluate_cal,
@@ -256,19 +256,19 @@ def test_exceptions(self):
         ]:
             with self.assertRaises(Exception) as exc:
                 func(Xval=Xval)
-            self.assertTrue(
-                str(exc.exception) == "CATE predictions not yet calculated - must provide both Xval, Xtrain")
+            self.assertEqual(
+                str(exc.exception), "CATE predictions not yet calculated - must provide both Xval, Xtrain")
 
         cal_res = my_dr_tester.evaluate_cal(Xval, Xtrain)
         self.assertGreater(cal_res.cal_r_squared[0], 0)  # good R2
 
         with self.assertRaises(Exception) as exc:
             my_dr_tester.evaluate_uplift(metric='blah')
-        self.assertTrue(
-            str(exc.exception) == "Unsupported metric - must be one of ['toc', 'qini']"
+        self.assertEqual(
+            str(exc.exception), "Unsupported metric 'blah' - must be one of ['toc', 'qini']"
        )
 
-        my_dr_tester = DRtester(
+        my_dr_tester = DRTester(
            model_regression=reg_y,
            model_propensity=reg_t,
            cate=cate
@@ -278,5 +278,11 @@ def test_exceptions(self):
         qini_res = my_dr_tester.evaluate_uplift(Xval, Xtrain)
         self.assertLess(qini_res.pvals[0], 0.05)
 
+        with self.assertRaises(Exception) as exc:
+            qini_res.plot_uplift(tmt=1, err_type='blah')
+        self.assertEqual(
+            str(exc.exception), "Invalid error type 'blah'; must be one of [None, 'ucb2', 'ucb1']"
+        )
+
         autoc_res = my_dr_tester.evaluate_uplift(Xval, Xtrain, metric='toc')
         self.assertLess(autoc_res.pvals[0], 0.05)

econml/validate/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,9 @@
 A suite of validation methods for CATE models.
 """
 
-from .drtester import DRtester
+from .drtester import DRTester
+from .results import BLPEvaluationResults, CalibrationEvaluationResults, UpliftEvaluationResults, EvaluationResults
 
 
-__all__ = ['DRtester']
+__all__ = ['DRTester',
+           'BLPEvaluationResults', 'CalibrationEvaluationResults', 'UpliftEvaluationResults', 'EvaluationResults']
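
With the expanded __all__, the validation entry points can be imported directly from the econml.validate namespace; a minimal illustration (the names are taken verbatim from the export list above):

# Import the validation classes from the package namespace rather than the submodules.
from econml.validate import (
    DRTester,
    BLPEvaluationResults,
    CalibrationEvaluationResults,
    UpliftEvaluationResults,
    EvaluationResults,
)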
