Commit bb042d5

Cython implementation of GRF and CausalForestDML (#341)
* added backend option in orf, adding verbosity, restructuring static functions
* added cython grf module that implements generalized random forests
* added cython version of causal forest and causal forest dml
* deprecating older CausalForest
* updates to CF and ORF notebook
* restructured dml into folder. Deprecated ForestDML in favor of CausalForestDML.
* Removed two legacy files in our main folder.
* deprecating ensemble.SubsampledHonestForest
* made drlearner use the non-deprecated regression forest.
* Enable setuptools build process
* fixed flaky random_state test
* fixed tests and api consistency
* updated tables and library flow chart
* enforce sklearn 0.24.
* fixed _cross_val_predict
* added option for max background samples to shap to make computation more reasonable
* fixed error_score param in gcvlist due to sklearn upgrade
* added shap cells in DML notebook
* added shap values to GRF notebook
* fixed bug in the way input_feature_names were used in summary. enabled shap to use input feature names
* updated readme. removed autoreload from notebooks
* added shap specific notebook
* updated dowhy notebook
1 parent 3df959d commit bb042d5

85 files changed: +14660 -3772 lines changed

LICENSE

Lines changed: 37 additions & 0 deletions
@@ -19,3 +19,40 @@
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
+
+
+Parts of this software, in particular code contained in the modules econml.tree and
+econml.grf contain files that are forks from the scikit-learn git repository, or code
+snippets from that repository:
+https://github.com/scikit-learn/scikit-learn
+published under the following License.
+
+BSD 3-Clause License
+
+Copyright (c) 2007-2020 The scikit-learn developers.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

README.md

Lines changed: 23 additions & 22 deletions
@@ -118,20 +118,7 @@ To install from source, see [For Developers](#for-developers) section below.
 treatment_effects = est.effect(X_test)
 lb, ub = est.effect_interval(X_test, alpha=0.05) # Confidence intervals via debiased lasso
 ```
-
-* Forest last stage
-
-```Python
-from econml.dml import ForestDML
-from sklearn.ensemble import GradientBoostingRegressor
 
-est = ForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())
-est.fit(Y, T, X=X, W=W)
-treatment_effects = est.effect(X_test)
-# Confidence intervals via Bootstrap-of-Little-Bags for forests
-lb, ub = est.effect_interval(X_test, alpha=0.05)
-```
-
 * Generic Machine Learning last stage
 
 ```Python
@@ -152,16 +139,16 @@ To install from source, see [For Developers](#for-developers) section below.
 <summary>Causal Forests (click to expand)</summary>
 
 ```Python
-from econml.causal_forest import CausalForest
+from econml.dml import CausalForestDML
 from sklearn.linear_model import LassoCV
 # Use defaults
-est = CausalForest()
+est = CausalForestDML()
 # Or specify hyperparameters
-est = CausalForest(n_trees=500, min_leaf_size=10,
-                   max_depth=10, subsample_ratio=0.7,
-                   lambda_reg=0.01,
-                   discrete_treatment=False,
-                   model_T=LassoCV(), model_Y=LassoCV())
+est = CausalForestDML(criterion='het', n_estimators=500,
+                      min_samples_leaf=10,
+                      max_depth=10, max_samples=0.5,
+                      discrete_treatment=False,
+                      model_t=LassoCV(), model_y=LassoCV())
 est.fit(Y, T, X=X, W=W)
 treatment_effects = est.effect(X_test)
 # Confidence intervals via Bootstrap-of-Little-Bags for forests
@@ -354,7 +341,7 @@ treatment_effects = est.effect(X_test)
 
 <details>
 <summary>Policy Interpreter of the CATE model (click to expand)</summary>
-
+
 ```Python
 from econml.cate_interpreter import SingleTreePolicyInterpreter
 # We find a tree-based treatment policy based on the CATE model
@@ -366,7 +353,21 @@ treatment_effects = est.effect(X_test)
 plt.show()
 ```
 ![image](notebooks/images/dr_policy_tree.png)
-
+
+</details>
+
+<details>
+<summary>SHAP values for the CATE model (click to expand)</summary>
+
+```Python
+import shap
+from econml.dml import CausalForestDML
+est = CausalForestDML()
+est.fit(Y, T, X=X, W=W)
+shap_values = est.shap_values(X)
+shap.summary_plot(shap_values['Y0']['T0'])
+```
+
 </details>
 
 ### Inference
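
Note on the README change: the hunks above replace `CausalForest` with `CausalForestDML` and rename most of its keyword arguments. As a rough migration aid, the sketch below translates the old keyword names into the new ones. The mapping is inferred only from the before/after lines in this diff, not from EconML's documentation, and the helper name `translate_causal_forest_kwargs` is hypothetical. The README also changes the subsampling value from 0.7 to 0.5, so values may not carry over unchanged even where the names map cleanly.

```python
# Hypothetical migration helper; the mapping is an assumption read off the
# README diff above, not an official EconML compatibility table.
RENAMED_KWARGS = {
    "n_trees": "n_estimators",
    "min_leaf_size": "min_samples_leaf",
    "subsample_ratio": "max_samples",
    "model_T": "model_t",
    "model_Y": "model_y",
}

# Present in the old README call, absent from the new one.
DROPPED_KWARGS = {"lambda_reg"}


def translate_causal_forest_kwargs(old_kwargs):
    """Return a new dict with old CausalForest keyword names renamed.

    Values are copied unchanged; callers should still review them
    (for example the subsampling fraction) before reuse.
    """
    new_kwargs = {}
    for name, value in old_kwargs.items():
        if name in DROPPED_KWARGS:
            continue  # no counterpart appears in the new README call
        new_kwargs[RENAMED_KWARGS.get(name, name)] = value
    return new_kwargs


old = {"n_trees": 500, "min_leaf_size": 10, "max_depth": 10,
       "subsample_ratio": 0.7, "lambda_reg": 0.01}
new = translate_causal_forest_kwargs(old)
print(new)
```

Names that the diff leaves unchanged, such as `max_depth` and `discrete_treatment`, pass through as-is.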

azure-pipelines-steps.yml

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 parameters:
   body: []
-  package: '.'
+  package: '-e .'
 
 steps:
 - task: UsePythonVersion@0
@@ -24,7 +24,7 @@ steps:
   condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux'))
 
 # Install the package
-- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel && pip install ${{ parameters.package }}'
+- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel Cython && pip install ${{ parameters.package }}'
   displayName: 'Install dependencies'
 
 - ${{ parameters.body }}

azure-pipelines.yml

Lines changed: 1 addition & 4 deletions
@@ -69,9 +69,6 @@ jobs:
 - script: 'pip install --force-reinstall --no-cache-dir shap'
   displayName: 'Install public shap'
 
-- script: 'pip install --force-reinstall scikit-learn==0.23.2'
-  displayName: 'Install public old sklearn'
-
 - script: 'python setup.py build_sphinx -W'
   displayName: 'Build documentation'
 
@@ -81,7 +78,7 @@ jobs:
 
 - script: 'python setup.py build_sphinx -b doctest'
   displayName: 'Run doctests'
-package: '.[automl]'
+package: '-e .[automl]'
 
 - job: 'Notebooks'
   dependsOn: 'EvalChanges'

doc/conf.py

Lines changed: 1 addition & 1 deletion
@@ -211,7 +211,7 @@
 # Example configuration for intersphinx: refer to the Python standard library.
 intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
                        'numpy': ('https://docs.scipy.org/doc/numpy/', None),
-                       'sklearn': ('https://scikit-learn.org/0.23/', None),
+                       'sklearn': ('https://scikit-learn.org/stable/', None),
                        'matplotlib': ('https://matplotlib.org/', None)}
 
 # -- Options for todo extension ----------------------------------------------

doc/map.svg

Lines changed: 14 additions & 14 deletions

doc/reference.rst

Lines changed: 7 additions & 6 deletions
@@ -5,14 +5,11 @@ Public Module Reference
     :toctree: _autosummary
 
     econml.bootstrap
-    econml.cate_estimator
     econml.cate_interpreter
-    econml.causal_forest
-    econml.causal_tree
     econml.deepiv
-    econml.dgp
     econml.dml
     econml.drlearner
+    econml.grf
     econml.inference
     econml.metalearners
     econml.ortho_forest
@@ -27,7 +24,12 @@ Private Module Reference
     :toctree: _autosummary
 
     econml._ortho_learner
-    econml._rlearner
+    econml._cate_estimator
+    econml._causal_tree
+    econml.dml._rlearner
+    econml.grf._base_grf
+    econml.grf._base_grftree
+    econml.grf._criterion
 
 Scikit-Learn Extensions
 =======================
@@ -37,4 +39,3 @@ Scikit-Learn Extensions
 
     econml.sklearn_extensions.linear_model
     econml.sklearn_extensions.model_selection
-    econml.sklearn_extensions.ensemble

doc/spec/api.rst

Lines changed: 13 additions & 4 deletions
@@ -27,10 +27,6 @@ The latter translates to estimating a local gradient around a treatment vector c
     \partial\tau(\vec{t}, \vec{x}) = \E\left[\nabla_{\vec{t}} Y(\vec{t}) | X=\vec{x}\right] \tag{marginal CATE}
 
 We will refer to the latter as the *heterogeneous marginal effect*. [1]_
-Finally, we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quatity:
-
-.. math ::
-    \mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}
 
 We assume we have data that are generated from some collection policy. In particular, we assume that we have data of the form:
 :math:`\{Y_i(T_i), T_i, X_i, W_i, Z_i\}`, where :math:`Y_i(T_i)` is the observed outcome for the chosen treatment,
@@ -43,6 +39,19 @@ The variables :math:`X_i` can also be thought of as *control* variables, but the
 they are a subset of the controls with respect to which we want to measure treatment effect heterogeneity.
 We will refer to them as *features*.
 
+Finally, sometimes we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quantity:
+
+.. math ::
+    \mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}
+
+Our package does not offer support for counterfactual prediction. However, for most of our estimators (the ones
+assuming a linear-in-treatment model), counterfactual prediction can be easily constructed by combining any baseline predictive model
+with our causal effect model, i.e. train any machine learning model :math:`b(\vec{t}, \vec{x})` to solve the regression/classification
+problem :math:`\E[Y | T=\vec{t}, X=\vec{x}]`, and then set :math:`\mu(\vec{t}, \vec{x}) = \tau(\vec{t}, T, \vec{x}) + b(T, \vec{x})`,
+where :math:`T` is either the observed treatment for that sample under the observational policy or the treatment
+that the observational policy would have assigned to that sample. These auxiliary ML models can be trained
+with any machine learning package outside of EconML.
+
 .. rubric::
     Structural Equation Formulation
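
Note on the counterfactual-prediction paragraph added in this file: it describes combining a baseline predictive model with a causal effect model. Below is a minimal sketch of that recipe, assuming a scalar treatment and a linear-in-treatment effect, so that moving the treatment from the observed value `t_obs` to `t_new` shifts the outcome by `theta(x) * (t_new - t_obs)`. The function names and toy models are hypothetical stand-ins, not EconML API; in practice `baseline` would be any fitted regressor for E[Y | T, X] and `marginal_effect` a fitted CATE model.

```python
from typing import Callable, Sequence

Features = Sequence[float]


def counterfactual_prediction(
    t_new: float,
    t_obs: float,
    x: Features,
    baseline: Callable[[float, Features], float],
    marginal_effect: Callable[[Features], float],
) -> float:
    """Return mu(t_new, x) = b(t_obs, x) + theta(x) * (t_new - t_obs)."""
    # Effect of moving the treatment from its observed value to t_new,
    # under the (assumed) linear-in-treatment model.
    effect = marginal_effect(x) * (t_new - t_obs)
    # Baseline prediction at the observed treatment plus that effect.
    return baseline(t_obs, x) + effect


# Toy stand-ins for fitted models (hypothetical, for illustration only).
def toy_baseline(t: float, x: Features) -> float:
    return 1.0 + 2.0 * x[0] * t


def toy_marginal_effect(x: Features) -> float:
    return 2.0 * x[0]


mu = counterfactual_prediction(3.0, 1.0, [0.5], toy_baseline, toy_marginal_effect)
print(mu)
```

In this toy setup the baseline is exactly linear in the treatment, so the constructed value coincides with evaluating the baseline directly at the new treatment; with real fitted models the two would generally differ.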

doc/spec/comparison.rst

Lines changed: 3 additions & 3 deletions
@@ -19,13 +19,13 @@ Detailed estimator comparison
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
 | :class:`.LinearDRLearner`                   | Categorical  |              | Yes              |             | Projected       |            | Yes          |                    |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
-| :class:`.ForestDML`                         | 1-d/Binary   |              | Yes              | Yes         |                 | Yes        |              | Yes                |
+| :class:`.CausalForestDML`                   | Any          |              | Yes              | Yes         |                 | Yes        | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
 | :class:`.ForestDRLearner`                   | Categorical  |              | Yes              |             |                 |            | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
-| :class:`.ContinuousTreatmentOrthoForest`    | Continuous   |              | Yes              | Yes         |                 |            | Yes          | Yes                |
+| :class:`.DMLOrthoForest`                    | Any          |              | Yes              | Yes         |                 |            | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
-| :class:`.DiscreteTreatmentOrthoForest`      | Categorical  |              | Yes              |             |                 |            | Yes          | Yes                |
+| :class:`.DROrthoForest`                     | Categorical  |              | Yes              |             |                 |            | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
 | :mod:`~econml.metalearners`                 | Categorical  |              |                  |             |                 | Yes        | Yes          | Yes                |
 +---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
