55import scipy .stats as st
66from sklearn .ensemble import RandomForestClassifier , GradientBoostingRegressor
77
8- from econml .validate .drtester import DRtester
8+ from econml .validate .drtester import DRTester
99from econml .dml import DML
1010
1111
@@ -70,7 +70,7 @@ def test_multi(self):
7070 ).fit (Y = Ytrain , T = Dtrain , X = Xtrain )
7171
7272 # test the DR outcome difference
73- my_dr_tester = DRtester (
73+ my_dr_tester = DRTester (
7474 model_regression = reg_y ,
7575 model_propensity = reg_t ,
7676 cate = cate
@@ -123,7 +123,7 @@ def test_binary(self):
123123 ).fit (Y = Ytrain , T = Dtrain , X = Xtrain )
124124
125125 # test the DR outcome difference
126- my_dr_tester = DRtester (
126+ my_dr_tester = DRTester (
127127 model_regression = reg_y ,
128128 model_propensity = reg_t ,
129129 cate = cate
@@ -148,8 +148,8 @@ def test_binary(self):
148148 self .assertRaises (ValueError , res .plot_toc , k )
149149 else : # real treatment, k = 1
150150 self .assertTrue (res .plot_cal (k ) is not None )
151- self .assertTrue (res .plot_qini (k ) is not None )
152- self .assertTrue (res .plot_toc (k ) is not None )
151+ self .assertTrue (res .plot_qini (k , 'ucb2' ) is not None )
152+ self .assertTrue (res .plot_toc (k , 'ucb1' ) is not None )
153153
154154 self .assertLess (res_df .blp_pval .values [0 ], 0.05 ) # heterogeneity
155155 self .assertGreater (res_df .cal_r_squared .values [0 ], 0 ) # good R2
@@ -171,7 +171,7 @@ def test_nuisance_val_fit(self):
171171 ).fit (Y = Ytrain , T = Dtrain , X = Xtrain )
172172
173173 # test the DR outcome difference
174- my_dr_tester = DRtester (
174+ my_dr_tester = DRTester (
175175 model_regression = reg_y ,
176176 model_propensity = reg_t ,
177177 cate = cate
@@ -193,8 +193,8 @@ def test_nuisance_val_fit(self):
193193 for kwargs in [{}, {'Xval' : Xval }]:
194194 with self .assertRaises (Exception ) as exc :
195195 my_dr_tester .evaluate_cal (kwargs )
196- self .assertTrue (
197- str (exc .exception ) == "Must fit nuisance models on training sample data to use calibration test"
196+ self .assertEqual (
197+ str (exc .exception ), "Must fit nuisance models on training sample data to use calibration test"
198198 )
199199
200200 def test_exceptions (self ):
@@ -212,7 +212,7 @@ def test_exceptions(self):
212212 ).fit (Y = Ytrain , T = Dtrain , X = Xtrain )
213213
214214 # test the DR outcome difference
215- my_dr_tester = DRtester (
215+ my_dr_tester = DRTester (
216216 model_regression = reg_y ,
217217 model_propensity = reg_t ,
218218 cate = cate
@@ -223,11 +223,11 @@ def test_exceptions(self):
223223 with self .assertRaises (Exception ) as exc :
224224 func ()
225225 if func .__name__ == 'evaluate_cal' :
226- self .assertTrue (
227- str (exc .exception ) == "Must fit nuisance models on training sample data to use calibration test"
226+ self .assertEqual (
227+ str (exc .exception ), "Must fit nuisance models on training sample data to use calibration test"
228228 )
229229 else :
230- self .assertTrue (str (exc .exception ) == "Must fit nuisances before evaluating" )
230+ self .assertEqual (str (exc .exception ), "Must fit nuisances before evaluating" )
231231
232232 my_dr_tester = my_dr_tester .fit_nuisance (
233233 Xval , Dval , Yval , Xtrain , Dtrain , Ytrain
@@ -242,12 +242,12 @@ def test_exceptions(self):
242242 with self .assertRaises (Exception ) as exc :
243243 func ()
244244 if func .__name__ == 'evaluate_blp' :
245- self .assertTrue (
246- str (exc .exception ) == "CATE predictions not yet calculated - must provide Xval"
245+ self .assertEqual (
246+ str (exc .exception ), "CATE predictions not yet calculated - must provide Xval"
247247 )
248248 else :
249- self .assertTrue (str (exc .exception ) ==
250- "CATE predictions not yet calculated - must provide both Xval, Xtrain" )
249+ self .assertEqual (str (exc .exception ),
250+ "CATE predictions not yet calculated - must provide both Xval, Xtrain" )
251251
252252 for func in [
253253 my_dr_tester .evaluate_cal ,
@@ -256,19 +256,19 @@ def test_exceptions(self):
256256 ]:
257257 with self .assertRaises (Exception ) as exc :
258258 func (Xval = Xval )
259- self .assertTrue (
260- str (exc .exception ) == "CATE predictions not yet calculated - must provide both Xval, Xtrain" )
259+ self .assertEqual (
260+ str (exc .exception ), "CATE predictions not yet calculated - must provide both Xval, Xtrain" )
261261
262262 cal_res = my_dr_tester .evaluate_cal (Xval , Xtrain )
263263 self .assertGreater (cal_res .cal_r_squared [0 ], 0 ) # good R2
264264
265265 with self .assertRaises (Exception ) as exc :
266266 my_dr_tester .evaluate_uplift (metric = 'blah' )
267- self .assertTrue (
268- str (exc .exception ) == "Unsupported metric - must be one of ['toc', 'qini']"
267+ self .assertEqual (
268+ str (exc .exception ), "Unsupported metric 'blah' - must be one of ['toc', 'qini']"
269269 )
270270
271- my_dr_tester = DRtester (
271+ my_dr_tester = DRTester (
272272 model_regression = reg_y ,
273273 model_propensity = reg_t ,
274274 cate = cate
@@ -278,5 +278,11 @@ def test_exceptions(self):
278278 qini_res = my_dr_tester .evaluate_uplift (Xval , Xtrain )
279279 self .assertLess (qini_res .pvals [0 ], 0.05 )
280280
281+ with self .assertRaises (Exception ) as exc :
282+ qini_res .plot_uplift (tmt = 1 , err_type = 'blah' )
283+ self .assertEqual (
284+ str (exc .exception ), "Invalid error type 'blah'; must be one of [None, 'ucb2', 'ucb1']"
285+ )
286+
281287 autoc_res = my_dr_tester .evaluate_uplift (Xval , Xtrain , metric = 'toc' )
282288 self .assertLess (autoc_res .pvals [0 ], 0.05 )
0 commit comments