Skip to content

Commit d48c2e4

Browse files
committed
Update estimator factories
1 parent e0e52fa commit d48c2e4

File tree

3 files changed

+64
-4
lines changed

3 files changed

+64
-4
lines changed

tests/estimator_factory_test.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
13
import pytest
24

35
from ylearn.estimator_model import ESTIMATOR_FACTORIES
@@ -33,7 +35,7 @@ def test_Xb_Yc(key):
3335

3436

3537
@if_policy_tree_ready
36-
@pytest.mark.parametrize('key', ['tree', ])
38+
@pytest.mark.parametrize('key', ['tree', 'grf'])
3739
def test_Xb_Yc_tree(key):
3840
data, test_data, outcome, treatment, adjustment, covariate = _dgp.generate_data_x1b_y1()
3941

@@ -42,7 +44,12 @@ def test_Xb_Yc_tree(key):
4244
adjustment=adjustment, covariate=covariate, random_state=123)
4345
assert est is not None
4446

45-
est.fit(data, outcome, treatment, adjustment=adjustment, covariate=covariate, n_jobs=1)
47+
fit_options = {}
48+
sig = inspect.signature(est.fit)
49+
if 'n_jobs' in sig.parameters.keys():
50+
fit_options['n_jobs'] = 1
51+
52+
est.fit(data, outcome, treatment, adjustment=adjustment, covariate=covariate, **fit_options)
4653
effect = est.estimate(test_data)
4754
assert effect.shape[0] == len(test_data)
4855

@@ -62,3 +69,21 @@ def test_Xb_Yb(key):
6269
est.fit(data, outcome, treatment, adjustment=adjustment, covariate=covariate, n_jobs=1)
6370
effect = est.estimate(test_data)
6471
assert effect.shape[0] == len(test_data)
72+
73+
74+
@if_policy_tree_ready
75+
@pytest.mark.parametrize('key', ['tree', 'grf', ])
76+
def test_Xb_Yb_tree(key):
77+
data, test_data, outcome, treatment, adjustment, covariate = _dgp.generate_data_x1b_y1()
78+
m = data[outcome].values.mean()
79+
data[outcome] = (data[outcome] > m).astype('int')
80+
test_data[outcome] = (test_data[outcome] > m).astype('int')
81+
82+
factory = ESTIMATOR_FACTORIES[key]()
83+
est = factory(data, outcome[0], treatment, 'binary', 'binary',
84+
adjustment=adjustment, covariate=covariate, random_state=123)
85+
assert est is not None
86+
87+
est.fit(data, outcome, treatment, adjustment=adjustment, covariate=covariate)
88+
effect = est.estimate(test_data)
89+
assert effect.shape[0] == len(test_data)

tests/why_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def test_policy_interpreter_discrete_x2_yb_tlearner():
172172
data[outcome] = (data[outcome] > m).astype('int')
173173
test_data[outcome] = (test_data[outcome] > m).astype('int')
174174
# why = Why()
175-
why = Why(estimator='ml', estimator_options=dict(learner='t', model='lr'))
175+
# why = Why(estimator='ml', estimator_options=dict(learner='t', model='lr'))
176+
why = Why(estimator='tlearner', estimator_options=dict(model='lr'))
176177
why.fit(data, outcome[0], treatment=treatment, adjustment=adjustment, covariate=covariate)
177178

178179
pi = why.policy_interpreter(test_data)
@@ -189,7 +190,6 @@ def test_policy_interpreter_discrete_x2_yb_dml():
189190
data[outcome] = (data[outcome] > m).astype('int')
190191
test_data[outcome] = (test_data[outcome] > m).astype('int')
191192
why = Why(estimator='dml')
192-
# why = Why(estimator='ml', estimator_options=dict(learner='t', model='lr'))
193193
why.fit(data, outcome[0], treatment=treatment, adjustment=adjustment, covariate=covariate)
194194

195195
pi = why.policy_interpreter(test_data)

ylearn/estimator_model/_factory.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
157157
is_discrete_treatment = x_task if isinstance(x_task, bool) else x_task != const.TASK_REGRESSION
158158
is_discrete_outcome = y_task if isinstance(y_task, bool) else y_task != const.TASK_REGRESSION
159159

160+
assert is_discrete_treatment, 'SLearner support discrete treatment only.'
161+
160162
return PermutedSLearner(
161163
model=self._model(data, task=y_task, estimator=self.model, random_state=random_state),
162164
is_discrete_outcome=is_discrete_outcome,
@@ -179,6 +181,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
179181
is_discrete_treatment = x_task if isinstance(x_task, bool) else x_task != const.TASK_REGRESSION
180182
is_discrete_outcome = y_task if isinstance(y_task, bool) else y_task != const.TASK_REGRESSION
181183

184+
assert is_discrete_treatment, 'TLearner support discrete treatment only.'
185+
182186
return PermutedTLearner(
183187
model=self._model(data, task=y_task, estimator=self.model, random_state=random_state),
184188
is_discrete_outcome=is_discrete_outcome,
@@ -202,6 +206,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
202206
is_discrete_treatment = x_task if isinstance(x_task, bool) else x_task != const.TASK_REGRESSION
203207
is_discrete_outcome = y_task if isinstance(y_task, bool) else y_task != const.TASK_REGRESSION
204208

209+
assert is_discrete_treatment, 'XLearner support discrete treatment only.'
210+
205211
if is_discrete_outcome:
206212
final_proba_model = self._model(
207213
data, task=const.TASK_REGRESSION, estimator=self.final_proba_model, random_state=random_state)
@@ -229,12 +235,41 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
229235
adjustment=None, covariate=None, instrument=None, random_state=None):
230236
from ylearn.estimator_model._permuted import PermutedCausalTree
231237

238+
is_discrete_treatment = x_task if isinstance(x_task, bool) else x_task != const.TASK_REGRESSION
239+
is_discrete_outcome = y_task if isinstance(y_task, bool) else y_task != const.TASK_REGRESSION
240+
241+
assert is_discrete_treatment, 'CausalTree support discrete treatment only.'
242+
232243
options = self.options.copy()
233244
if random_state is not None:
234245
options['random_state'] = random_state
246+
# options['is_discrete_outcome'] = is_discrete_outcome
247+
# options['is_discrete_treatment'] = is_discrete_treatment
248+
235249
return PermutedCausalTree(**options)
236250

237251

252+
@register()
253+
class GrfFactory(BaseEstimatorFactory):
254+
def __init__(self, **kwargs):
255+
self.options = kwargs.copy()
256+
257+
def __call__(self, data, outcome, treatment, y_task, x_task,
258+
adjustment=None, covariate=None, instrument=None, random_state=None):
259+
from ylearn.estimator_model._generalized_forest import GRForest
260+
261+
is_discrete_treatment = x_task if isinstance(x_task, bool) else x_task != const.TASK_REGRESSION
262+
is_discrete_outcome = y_task if isinstance(y_task, bool) else y_task != const.TASK_REGRESSION
263+
264+
options = self.options.copy()
265+
if random_state is not None:
266+
options['random_state'] = random_state
267+
options['is_discrete_outcome'] = is_discrete_outcome
268+
options['is_discrete_treatment'] = is_discrete_treatment
269+
270+
return GRForest(**options)
271+
272+
238273
@register()
239274
@register(name='bound')
240275
class ApproxBoundFactory(BaseEstimatorFactory):

0 commit comments

Comments
 (0)