Skip to content

Commit b5c25cc

Browse files
authored
RScorer class for causal model selection (#361)
* added rscorer for model selection * added readme on model selection
1 parent 9a96875 commit b5c25cc

File tree

11 files changed

+1129
-14
lines changed

11 files changed

+1129
-14
lines changed

README.md

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ treatment_effects = est.effect(X_test)
319319
```
320320
</details>
321321

322-
See the <a href="#references">References</a> section for more details.
322+
See the <a href="#references">References</a> section for more details.
323323

324324
### Interpretability
325325
<details>
@@ -370,6 +370,54 @@ treatment_effects = est.effect(X_test)
370370

371371
</details>
372372

373+
374+
### Causal Model Selection and Cross-Validation
375+
376+
377+
<details>
378+
<summary>Causal model selection with the `RScorer` (click to expand)</summary>
379+
380+
```Python
381+
from econml.score import Rscorer
382+
383+
# split data in train-validation
384+
X_train, X_val, T_train, T_val, Y_train, Y_val = train_test_split(X, T, y, test_size=.4)
385+
386+
# define list of CATE estimators to select among
387+
reg = lambda: RandomForestRegressor(min_samples_leaf=20)
388+
clf = lambda: RandomForestClassifier(min_samples_leaf=20)
389+
models = [('ldml', LinearDML(model_y=reg(), model_t=clf(), discrete_treatment=True,
390+
linear_first_stages=False, n_splits=3)),
391+
('xlearner', XLearner(models=reg(), cate_models=reg(), propensity_model=clf())),
392+
('dalearner', DomainAdaptationLearner(models=reg(), final_models=reg(), propensity_model=clf())),
393+
('slearner', SLearner(overall_model=reg())),
394+
('drlearner', DRLearner(model_propensity=clf(), model_regression=reg(),
395+
model_final=reg(), n_splits=3)),
396+
('rlearner', NonParamDML(model_y=reg(), model_t=clf(), model_final=reg(),
397+
discrete_treatment=True, n_splits=3)),
398+
('dml3dlasso', DML(model_y=reg(), model_t=clf(),
399+
model_final=LassoCV(cv=3, fit_intercept=False),
400+
discrete_treatment=True,
401+
featurizer=PolynomialFeatures(degree=3),
402+
linear_first_stages=False, n_splits=3))
403+
]
404+
405+
# fit cate models on train data
406+
models = [(name, mdl.fit(Y_train, T_train, X=X_train)) for name, mdl in models]
407+
408+
# score cate models on validation data
409+
scorer = RScorer(model_y=reg(), model_t=clf(),
410+
discrete_treatment=True, n_splits=3, mc_iters=2, mc_agg='median')
411+
scorer.fit(Y_val, T_val, X=X_val)
412+
rscore = [scorer.score(mdl) for _, mdl in models]
413+
# select the best model
414+
mdl, _ = scorer.best_model([mdl for _, mdl in models])
415+
# create weighted ensemble model based on score performance
416+
mdl, _ = scorer.ensemble([mdl for _, mdl in models])
417+
```
418+
419+
</details>
420+
373421
### Inference
374422

375423
Whenever inference is enabled, then one can get a more structure `InferenceResults` object with more elaborate inference information, such

doc/reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Public Module Reference
1414
econml.metalearners
1515
econml.ortho_forest
1616
econml.ortho_iv
17+
econml.score
1718
econml.two_stage_least_squares
1819
econml.utilities
1920

econml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
'cate_interpreter', 'causal_forest',
44
'data', 'deepiv', 'dml', 'drlearner', 'inference',
55
'metalearners', 'ortho_forest', 'ortho_iv',
6-
'sklearn_extensions', 'tree',
6+
'score', 'sklearn_extensions', 'tree',
77
'two_stage_least_squares', 'utilities']

econml/dml/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,31 @@
77
(aka residual outcome and residual treatment).
88
Then estimates a CATE model by regressing the residual outcome on the residual treatment
99
in a manner that accounts for heterogeneity in the regression coefficient, with respect
10-
to X.
10+
to X. For the theoretical foundations of these methods see [dml]_, [rlearner]_, [paneldml]_,
11+
[lassodml]_, [ortholearner]_.
1112
1213
References
1314
----------
1415
15-
\\ V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, and a. W. Newey.
16+
.. [dml] V. Chernozhukov, D. Chetverikov, M. Demirer, E. Duflo, C. Hansen, and a. W. Newey.
1617
Double Machine Learning for Treatment and Causal Parameters.
17-
https://arxiv.org/abs/1608.00060, 2016.
18+
`<https://arxiv.org/abs/1608.00060>`_, 2016.
1819
19-
\\ X. Nie and S. Wager.
20+
.. [rlearner] X. Nie and S. Wager.
2021
Quasi-Oracle Estimation of Heterogeneous Treatment Effects.
21-
arXiv preprint arXiv:1712.04912, 2017. URL http://arxiv.org/abs/1712.04912.
22+
arXiv preprint arXiv:1712.04912, 2017. URL `<http://arxiv.org/abs/1712.04912>`_.
2223
23-
\\ V. Chernozhukov, M. Goldman, V. Semenova, and M. Taddy.
24+
.. [paneldml] V. Chernozhukov, M. Goldman, V. Semenova, and M. Taddy.
2425
Orthogonal Machine Learning for Demand Estimation: High Dimensional Causal Inference in Dynamic Panels.
25-
https://arxiv.org/abs/1712.09988, December 2017.
26+
`<https://arxiv.org/abs/1712.09988>`_, December 2017.
2627
27-
\\ V. Chernozhukov, D. Nekipelov, V. Semenova, and V. Syrgkanis.
28+
.. [lassodml] V. Chernozhukov, D. Nekipelov, V. Semenova, and V. Syrgkanis.
2829
Two-Stage Estimation with a High-Dimensional Second Stage.
29-
https://arxiv.org/abs/1806.04823, 2018.
30+
`<https://arxiv.org/abs/1806.04823>`_, 2018.
3031
31-
\\ Dylan Foster, Vasilis Syrgkanis (2019).
32+
.. [ortholearner] Dylan Foster, Vasilis Syrgkanis (2019).
3233
Orthogonal Statistical Learning.
33-
ACM Conference on Learning Theory. https://arxiv.org/abs/1901.09036
34+
ACM Conference on Learning Theory. `<https://arxiv.org/abs/1901.09036>`_
3435
3536
"""
3637

econml/grf/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
""" An efficient Cython implementation of Generalized Random Forests [grf]_ and special
5+
case python classes.
6+
7+
References
8+
----------
9+
.. [grf] Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests."
10+
The Annals of Statistics 47.2 (2019): 1148-1178
11+
https://arxiv.org/pdf/1610.01271.pdf
12+
"""
13+
414
from ._criterion import LinearMomentGRFCriterion, LinearMomentGRFCriterionMSE
515
from .classes import CausalForest, CausalIVForest, RegressionForest, MultiOutputGRF
616

econml/metalearners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
"""Metalearners for heterogeneous treatment effects in the context of discrete treatments.
55
6-
For more details on these CATE methods, see <https://arxiv.org/abs/1706.03461>
6+
For more details on these CATE methods, see `<https://arxiv.org/abs/1706.03461>`_
77
(Künzel S., Sekhon J., Bickel P., Yu B.) on Arxiv.
88
"""
99

econml/score/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
"""
5+
A suite of scoring methods for scoring CATE models out-of-sample for the
6+
purpose of model selection.
7+
"""
8+
9+
from .rscorer import RScorer
10+
from .ensemble_cate import EnsembleCateEstimator
11+
12+
__all__ = ['RScorer',
13+
'EnsembleCateEstimator']

econml/score/ensemble_cate.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import numpy as np
5+
from sklearn.utils.validation import check_array
6+
from .._cate_estimator import BaseCateEstimator, LinearCateEstimator
7+
8+
9+
class EnsembleCateEstimator:
10+
""" A CATE estimator that represents a weighted ensemble of many
11+
CATE estimators. Returns their weighted effect prediction.
12+
13+
Parameters
14+
----------
15+
cate_models : list of BaseCateEstimator objects
16+
A list of fitted cate estimator objects that will be used in the ensemble.
17+
The models are passed by reference, and not copied internally, because we
18+
need the fitted objects, so any change to the passed models will affect
19+
the internal predictions (e.g. if the input models are refitted).
20+
weights : np.ndarray of shape (len(cate_models),)
21+
The weight placed on each model. Weights must be non-positive. The
22+
ensemble will predict effects based on the weighted average predictions
23+
of the cate_models estiamtors, weighted by the corresponding weight in `weights`.
24+
"""
25+
26+
def __init__(self, *, cate_models, weights):
27+
self.cate_models = cate_models
28+
self.weights = weights
29+
30+
def effect(self, X=None, *, T0=0, T1=1):
31+
return np.average([mdl.effect(X=X, T0=T0, T1=T1) for mdl in self.cate_models],
32+
weights=self.weights, axis=0)
33+
effect.__doc__ = BaseCateEstimator.effect.__doc__
34+
35+
def marginal_effect(self, T, X=None):
36+
return np.average([mdl.marginal_effect(T, X=X) for mdl in self.cate_models],
37+
weights=self.weights, axis=0)
38+
marginal_effect.__doc__ = BaseCateEstimator.marginal_effect.__doc__
39+
40+
def const_marginal_effect(self, X=None):
41+
if np.any([not hasattr(mdl, 'const_marginal_effect') for mdl in self.cate_models]):
42+
raise ValueError("One of the base CATE models in parameter `cate_models` does not support "
43+
"the `const_marginal_effect` method.")
44+
return np.average([mdl.const_marginal_effect(X=X) for mdl in self.cate_models],
45+
weights=self.weights, axis=0)
46+
const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__
47+
48+
@property
49+
def cate_models(self):
50+
return self._cate_models
51+
52+
@cate_models.setter
53+
def cate_models(self, value):
54+
if (not isinstance(value, list)) or (not np.all([isinstance(model, BaseCateEstimator) for model in value])):
55+
raise ValueError('Parameter `cate_models` should be a list of `BaseCateEstimator` objects.')
56+
self._cate_models = value
57+
58+
@property
59+
def weights(self):
60+
return self._weights
61+
62+
@weights.setter
63+
def weights(self, value):
64+
weights = check_array(value, accept_sparse=False, ensure_2d=False, allow_nd=False, dtype='numeric',
65+
force_all_finite=True)
66+
if np.any(weights < 0):
67+
raise ValueError("All weights in parameter `weights` must be non-negative.")
68+
self._weights = weights

0 commit comments

Comments
 (0)