Skip to content

Commit fa3e41e

Browse files
committed
AutoML fixes
1 parent 35c5418 commit fa3e41e

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

econml/automated_ml.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@
4141

4242
def setAutomatedMLWorkspace(create_workspace=False,
4343
create_resource_group=False, workspace_region=None, *,
44-
subscription_id=None, resource_group=None, workspace_name=None, auth=None):
44+
auth=None, subscription_id, resource_group, workspace_name):
4545
"""Set configuration file for AutomatedML actions with the EconML library. If
4646
``create_workspace`` is set true, a new workspace is created
47-
for the user. If ``create_workspace`` is set true, a new workspace is
48-
created for the user.
47+
for the user.
4948
5049
Parameters
5150
----------
@@ -68,8 +67,7 @@ def setAutomatedMLWorkspace(create_workspace=False,
6867
authentication portal in the browser.
6968
7069
subscription_id: String, required
71-
Definition of a class that will serve as the parent class of the
72-
AutomatedMLMixin. This class must inherit from _BaseDML.
70+
Azure subscription ID for the subscription under which to run the models
7371
7472
resource_group: String, required
7573
Name of resource group of workspace to be created or set.
@@ -285,12 +283,12 @@ def __init__(self, *args, **kwargs):
285283
# Loop through the kwargs and args if any of them is an AutoMLConfig file, pass them
286284
# create model and pass model into final.
287285
new_args = ()
288-
for var in args:
286+
for idx, arg in enumerate(args):
289287
# If item is an automl config, get its corresponding
290288
# AutomatedML Model and add it to new_Args
291-
if isinstance(var, EconAutoMLConfig):
292-
var = self._get_automated_ml_model(kwarg, key)
293-
new_args += (var,)
289+
if isinstance(arg, EconAutoMLConfig):
290+
arg = self._get_automated_ml_model(arg, f"arg{idx}")
291+
new_args += (arg,)
294292

295293
for key in kwargs:
296294
kwarg = kwargs[key]

econml/tests/test_automated_ml.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@
4545
AutomatedKernelDML = addAutomatedML(KernelDML)
4646
AutomatedNonParamDML = \
4747
addAutomatedML(NonParamDML)
48-
AutomatedForestDML = addAutomatedML(ForestDML)
48+
AutomatedCausalForestDML = addAutomatedML(CausalForestDML)
4949

5050
AUTOML_SETTINGS_REG = {
51-
'experiment_timeout_minutes': 1,
51+
'experiment_timeout_minutes': 15,
5252
'enable_early_stopping': True,
5353
'iteration_timeout_minutes': 1,
5454
'max_cores_per_iteration': 1,
@@ -61,7 +61,7 @@
6161
}
6262

6363
AUTOML_SETTINGS_CLF = {
64-
'experiment_timeout_minutes': 1,
64+
'experiment_timeout_minutes': 15,
6565
'enable_early_stopping': True,
6666
'iteration_timeout_minutes': 1,
6767
'max_cores_per_iteration': 1,
@@ -118,7 +118,7 @@ def automl_model_sample_weight_reg():
118118

119119

120120
@pytest.mark.automl
121-
class TestAutomatedDML(unittest.TestCase):
121+
class TestAutomatedML(unittest.TestCase):
122122

123123
@classmethod
124124
def setUpClass(cls):
@@ -134,7 +134,6 @@ def setUpClass(cls):
134134

135135
def test_nonparam(self):
136136
"""Testing the completion of the fit and effect estimation of an automated Nonparametic DML"""
137-
Y, T, X, _ = ihdp_surface_B()
138137
est = AutomatedNonParamDML(model_y=automl_model_reg(),
139138
model_t=automl_model_clf(),
140139
model_final=automl_model_sample_weight_reg(), featurizer=None,
@@ -144,7 +143,6 @@ def test_nonparam(self):
144143

145144
def test_param(self):
146145
"""Testing the completion of the fit and effect estimation of an automated Parametric DML"""
147-
Y, T, X, _ = ihdp_surface_B()
148146
est = AutomatedLinearDML(model_y=automl_model_reg(),
149147
model_t=GradientBoostingClassifier(),
150148
featurizer=None,
@@ -154,28 +152,21 @@ def test_param(self):
154152

155153
def test_forest_dml(self):
156154
"""Testing the completion of the fit and effect estimation of an AutomatedForestDML"""
157-
158-
Y, T, X, _ = ihdp_surface_B()
159-
est = AutomatedForestDML(model_y=automl_model_reg(),
160-
model_t=GradientBoostingClassifier(),
161-
discrete_treatment=True,
162-
n_estimators=1000,
163-
subsample_fr=.8,
164-
min_samples_leaf=10,
165-
min_impurity_decrease=0.001,
166-
verbose=0, min_weight_fraction_leaf=.01)
155+
est = AutomatedCausalForestDML(model_y=automl_model_reg(),
156+
model_t=GradientBoostingClassifier(),
157+
discrete_treatment=True,
158+
n_estimators=1000,
159+
max_samples=.4,
160+
min_samples_leaf=10,
161+
min_impurity_decrease=0.001,
162+
verbose=0, min_weight_fraction_leaf=.01)
167163
est.fit(Y, T, X=X)
168164
_ = est.effect(X)
169165

170-
171-
@pytest.mark.automl
172-
class TestAutomatedMetalearners(unittest.TestCase):
173-
174166
def test_TLearner(self):
175167
"""Testing the completion of the fit and effect estimation of an AutomatedTLearner"""
176168
# TLearner test
177169
# Instantiate TLearner
178-
Y, T, X, _ = ihdp_surface_B()
179170
est = AutomatedTLearner(models=automl_model_reg())
180171

181172
# Test constant and heterogeneous treatment effect, single and multi output y
@@ -188,7 +179,6 @@ def test_SLearner(self):
188179
# Test constant treatment effect with multi output Y
189180
# Test heterogeneous treatment effect
190181
# Need interactions between T and features
191-
Y, T, X, _ = ihdp_surface_B()
192182
est = AutomatedSLearner(overall_model=automl_model_reg())
193183

194184
est.fit(Y, T, X=X)
@@ -206,3 +196,20 @@ def test_DALearner(self):
206196

207197
est.fit(Y, T, X=X)
208198
_ = est.effect(X)
199+
200+
def test_positional(self):
201+
"""Test that positional arguments can be used with AutoML wrappers"""
202+
203+
class TestEstimator:
204+
def __init__(self, model_x):
205+
self.model_x = model_x
206+
207+
def fit(self, X, Y):
208+
self.model_x.fit(X, Y)
209+
return self
210+
211+
def predict(self, X):
212+
return self.model_x.predict(X)
213+
214+
AutoMLTestEstimator = addAutomatedML(TestEstimator)
215+
AutoMLTestEstimator(automl_model_reg()).fit(X, Y).predict(X)

0 commit comments

Comments
 (0)