4545 AutomatedKernelDML = addAutomatedML (KernelDML )
4646 AutomatedNonParamDML = \
4747 addAutomatedML (NonParamDML )
48- AutomatedForestDML = addAutomatedML (ForestDML )
48+ AutomatedCausalForestDML = addAutomatedML (CausalForestDML )
4949
5050 AUTOML_SETTINGS_REG = {
51- 'experiment_timeout_minutes' : 1 ,
51+ 'experiment_timeout_minutes' : 15 ,
5252 'enable_early_stopping' : True ,
5353 'iteration_timeout_minutes' : 1 ,
5454 'max_cores_per_iteration' : 1 ,
6161 }
6262
6363 AUTOML_SETTINGS_CLF = {
64- 'experiment_timeout_minutes' : 1 ,
64+ 'experiment_timeout_minutes' : 15 ,
6565 'enable_early_stopping' : True ,
6666 'iteration_timeout_minutes' : 1 ,
6767 'max_cores_per_iteration' : 1 ,
@@ -118,7 +118,7 @@ def automl_model_sample_weight_reg():
118118
119119
120120@pytest .mark .automl
121- class TestAutomatedDML (unittest .TestCase ):
121+ class TestAutomatedML (unittest .TestCase ):
122122
123123 @classmethod
124124 def setUpClass (cls ):
@@ -134,7 +134,6 @@ def setUpClass(cls):
134134
135135 def test_nonparam (self ):
136136 """Testing the completion of the fit and effect estimation of an automated Nonparametic DML"""
137- Y , T , X , _ = ihdp_surface_B ()
138137 est = AutomatedNonParamDML (model_y = automl_model_reg (),
139138 model_t = automl_model_clf (),
140139 model_final = automl_model_sample_weight_reg (), featurizer = None ,
@@ -144,7 +143,6 @@ def test_nonparam(self):
144143
145144 def test_param (self ):
146145 """Testing the completion of the fit and effect estimation of an automated Parametric DML"""
147- Y , T , X , _ = ihdp_surface_B ()
148146 est = AutomatedLinearDML (model_y = automl_model_reg (),
149147 model_t = GradientBoostingClassifier (),
150148 featurizer = None ,
@@ -154,28 +152,21 @@ def test_param(self):
154152
155153 def test_forest_dml (self ):
156154 """Testing the completion of the fit and effect estimation of an AutomatedForestDML"""
157-
158- Y , T , X , _ = ihdp_surface_B ()
159- est = AutomatedForestDML (model_y = automl_model_reg (),
160- model_t = GradientBoostingClassifier (),
161- discrete_treatment = True ,
162- n_estimators = 1000 ,
163- subsample_fr = .8 ,
164- min_samples_leaf = 10 ,
165- min_impurity_decrease = 0.001 ,
166- verbose = 0 , min_weight_fraction_leaf = .01 )
155+ est = AutomatedCausalForestDML (model_y = automl_model_reg (),
156+ model_t = GradientBoostingClassifier (),
157+ discrete_treatment = True ,
158+ n_estimators = 1000 ,
159+ max_samples = .4 ,
160+ min_samples_leaf = 10 ,
161+ min_impurity_decrease = 0.001 ,
162+ verbose = 0 , min_weight_fraction_leaf = .01 )
167163 est .fit (Y , T , X = X )
168164 _ = est .effect (X )
169165
170-
171- @pytest .mark .automl
172- class TestAutomatedMetalearners (unittest .TestCase ):
173-
174166 def test_TLearner (self ):
175167 """Testing the completion of the fit and effect estimation of an AutomatedTLearner"""
176168 # TLearner test
177169 # Instantiate TLearner
178- Y , T , X , _ = ihdp_surface_B ()
179170 est = AutomatedTLearner (models = automl_model_reg ())
180171
181172 # Test constant and heterogeneous treatment effect, single and multi output y
@@ -188,7 +179,6 @@ def test_SLearner(self):
188179 # Test constant treatment effect with multi output Y
189180 # Test heterogeneous treatment effect
190181 # Need interactions between T and features
191- Y , T , X , _ = ihdp_surface_B ()
192182 est = AutomatedSLearner (overall_model = automl_model_reg ())
193183
194184 est .fit (Y , T , X = X )
@@ -206,3 +196,20 @@ def test_DALearner(self):
206196
207197 est .fit (Y , T , X = X )
208198 _ = est .effect (X )
199+
200+ def test_positional (self ):
201+ """Test that positional arguments can be used with AutoML wrappers"""
202+
203+ class TestEstimator :
204+ def __init__ (self , model_x ):
205+ self .model_x = model_x
206+
207+ def fit (self , X , Y ):
208+ self .model_x .fit (X , Y )
209+ return self
210+
211+ def predict (self , X ):
212+ return self .model_x .predict (X )
213+
214+ AutoMLTestEstimator = addAutomatedML (TestEstimator )
215+ AutoMLTestEstimator (automl_model_reg ()).fit (X , Y ).predict (X )
0 commit comments