
Commit 69fadc3

Add doc for interpretability (#369)
* add doc for interpretability
* update the shap dependency to the release version and point specific version to doc test
1 parent fb34846 commit 69fadc3

7 files changed: +123 -6 lines changed

azure-pipelines.yml

Lines changed: 4 additions & 4 deletions
@@ -64,10 +64,10 @@ jobs:
6464
displayName: 'Install graphviz'
6565

6666
- script: 'pip install sklearn-contrib-lightning'
67-
displayName: 'Install lightning'
68-
69-
- script: 'pip install --force-reinstall --no-cache-dir shap'
70-
displayName: 'Install public shap'
67+
displayName: 'Install lightning'
68+
69+
- script: 'pip install git+https://github.com/slundberg/shap.git@d1d2700acc0259f211934373826d5ff71ad514de'
70+
displayName: 'Install specific version of shap'
7171

7272
- script: 'python setup.py build_sphinx -W'
7373
displayName: 'Build documentation'

doc/reference.rst

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ Private Module Reference
2727
econml._ortho_learner
2828
econml._cate_estimator
2929
econml._causal_tree
30+
econml._shap
3031
econml.dml._rlearner
3132
econml.grf._base_grf
3233
econml.grf._base_grftree

doc/spec/interpretability.rst

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
1+
Interpretability
2+
================
3+
4+
Our package offers multiple interpretability tools to help you better understand the final CATE model.
5+
6+
7+
Tree Interpreter
8+
----------------
9+
10+
Tree Interpreter provides a presentation-ready summary of the key features that explain the biggest differences in responsiveness to an intervention.
11+
12+
:class:`.SingleTreeCateInterpreter` trains a single shallow decision tree for the treatment effect :math:`\theta(X)` you learned from any of
13+
our available CATE estimators on a small set of features :math:`X` across which you want to understand effect heterogeneity. The tree splits at the cutoff
14+
points that maximize the difference in treatment effect between leaves. Each resulting leaf is then a subgroup of samples that responds to the treatment differently
15+
from other leaves.
16+
17+
For instance:
18+
19+
.. testsetup::
20+
21+
import numpy as np
22+
X = np.random.choice(np.arange(5), size=(100,3))
23+
Y = np.random.normal(size=(100,2))
24+
y = np.random.normal(size=(100,))
25+
T = np.random.choice(np.arange(3), size=(100,2))
26+
t = T[:,0]
27+
W = np.random.normal(size=(100,2))
28+
29+
30+
.. testcode::
31+
32+
from econml.cate_interpreter import SingleTreeCateInterpreter
33+
from econml.dml import LinearDML
34+
est = LinearDML()
35+
est.fit(y, t, X=X, W=W)
36+
intrp = SingleTreeCateInterpreter(include_model_uncertainty=True, max_depth=2, min_samples_leaf=10)
37+
# We interpret the CATE model's behavior based on the features used for heterogeneity
38+
intrp.interpret(est, X)
39+
# Plot the tree
40+
intrp.plot(feature_names=['A', 'B', 'C'], fontsize=12)
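The splitting idea described above can be illustrated with a small standalone sketch. This is not the library's implementation (the interpreter fits a full shallow tree over several features); it is a simplified single-feature, single-split version in plain Python, and the function name ``best_split`` and the effect values are made up for illustration.

```python
from statistics import mean


def best_split(feature, effects, min_samples_leaf=1):
    """Return (cutoff, gap) for the threshold that maximizes the absolute
    difference between the mean effects on either side of it."""
    pairs = sorted(zip(feature, effects))
    best = (None, 0.0)
    for i in range(min_samples_leaf, len(pairs) - min_samples_leaf + 1):
        # A cutoff is only possible where the feature value changes.
        if pairs[i - 1][0] == pairs[i][0]:
            continue
        left = mean(e for _, e in pairs[:i])
        right = mean(e for _, e in pairs[i:])
        gap = abs(right - left)
        if gap > best[1]:
            cutoff = (pairs[i - 1][0] + pairs[i][0]) / 2
            best = (cutoff, gap)
    return best


# Hypothetical per-sample effects theta(X): units with a larger feature
# value respond more strongly to the treatment.
feature = [0, 1, 2, 3, 4, 5]
effects = [0.1, 0.2, 0.1, 0.9, 1.0, 1.1]
cutoff, gap = best_split(feature, effects, min_samples_leaf=2)
print(cutoff, round(gap, 3))
```

In this toy data the chosen cutoff separates the low responders from the high responders, which is the kind of subgroup each leaf of the interpreter's tree is meant to represent.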
41+
42+
Policy Interpreter
43+
------------------
44+
Policy Interpreter offers similar functionality but takes cost into consideration.
45+
46+
Instead of fitting a tree to learn groups that have a different treatment effect, :class:`.SingleTreePolicyInterpreter` tries to split the samples into different treatment groups.
47+
So, in the case of binary treatments, it tries to create subgroups such that the samples within each group either all have a positive effect or all have a negative effect. Thus it tries to
48+
separate responders from non-responders, as opposed to trying to find groups that have different levels of response.
49+
50+
This way you can construct an interpretable personalized policy where you treat the groups with a positive effect and don't treat the groups with a negative effect.
51+
Our policy tree provides the recommended treatment at each leaf node.
52+
53+
54+
For instance:
55+
56+
.. testcode::
57+
58+
from econml.cate_interpreter import SingleTreePolicyInterpreter
59+
# We find a tree-based treatment policy based on the CATE model
60+
# sample_treatment_costs is the cost of treatment. Policy will treat if effect is above this cost.
61+
intrp = SingleTreePolicyInterpreter(risk_level=None, max_depth=2, min_samples_leaf=1, min_impurity_decrease=.001)
62+
intrp.interpret(est, X, sample_treatment_costs=0.02)
63+
# Plot the tree
64+
intrp.plot(feature_names=['A', 'B', 'C'], fontsize=12)
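The decision rule at a leaf can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not the library's algorithm: the groups are taken as given rather than learned, the helper name ``recommend_treatment`` and the effect values are hypothetical, and a group is recommended for treatment when its mean effect exceeds a scalar treatment cost.

```python
from statistics import mean


def recommend_treatment(effects_by_group, treatment_cost):
    """Map each group to (recommendation, mean net benefit of treating)."""
    policy = {}
    for group, effects in effects_by_group.items():
        # Net benefit of treating the group: average effect minus the cost.
        net_benefit = mean(effects) - treatment_cost
        label = "treat" if net_benefit > 0 else "don't treat"
        policy[group] = (label, net_benefit)
    return policy


# Hypothetical leaves of a policy tree with made-up per-sample effects.
effects_by_group = {
    "A <= 2.5": [-0.10, 0.00, 0.01],
    "A > 2.5": [0.30, 0.25, 0.35],
}
policy = recommend_treatment(effects_by_group, treatment_cost=0.02)
for group, (label, net_benefit) in policy.items():
    print(group, label, round(net_benefit, 3))
```

Note that a group with a small positive mean effect can still be left untreated once the cost is subtracted, which is the difference from the tree interpreter above.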
65+
66+
67+
SHAP
68+
----
69+
70+
`SHAP <https://shap.readthedocs.io/en/latest/>`_ is a popular open source library for interpreting black-box machine learning
71+
models using the Shapley values methodology (see e.g. [Lundberg2017]_).
72+
73+
Similar to how black-box predictive machine learning models can be explained with SHAP, we can also explain black-box effect
74+
heterogeneity models. This approach provides an explanation as to why a heterogeneous causal effect model produced larger or
75+
smaller effect values for particular segments of the population. Which features led to such differentiation?
76+
This question is easy to address when the model is succinctly described, such as the case of linear heterogeneity models,
77+
where one can simply investigate the coefficients of the model. However, it becomes hard when one starts using more expressive
78+
models, such as Random Forests and Causal Forests, to model effect heterogeneity. SHAP values can be of immense help in
79+
understanding the leading factors of effect heterogeneity that the model picked up from the training data.
80+
81+
Our package offers seamless integration with the SHAP library. Every CATE estimator has a method `shap_values`, which returns the
82+
SHAP value explanation of the estimator's output for every treatment and outcome pair. These values can then be visualized with
83+
the plethora of visualizations that the SHAP library offers. Moreover, whenever possible our library invokes fast specialized
84+
algorithms from the SHAP library, for each type of final model, which can greatly reduce computation times.
85+
86+
For instance:
87+
88+
.. testcode::
89+
90+
import shap
91+
from econml.dml import LinearDML
92+
est = LinearDML()
93+
est.fit(y, t, X=X, W=W)
94+
shap_values = est.shap_values(X)
95+
# local view: explain heterogeneity for a given observation
96+
ind = 0
97+
shap.plots.force(shap_values["Y0"]["T0"][ind], matplotlib=True)
98+
# global view: explain heterogeneity for a sample of the dataset
99+
shap.summary_plot(shap_values['Y0']['T0'])
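For intuition about what such values mean in the simple linear case mentioned above, the sketch below computes them by hand without the SHAP library. It assumes a linear effect model with independent features, for which the Shapley value of feature ``i`` reduces to ``coef_i * (x_i - mean(x_i))``; the coefficients, intercept, and data are made up for illustration and are not produced by any estimator.

```python
from statistics import mean


def linear_shap(coefs, background, x):
    """Shapley values of a linear model at x, assuming independent features."""
    baselines = [mean(column) for column in zip(*background)]
    return [w * (xi - b) for w, xi, b in zip(coefs, x, baselines)]


def predict(coefs, intercept, x):
    """Linear effect model theta(x) = intercept + sum_i coef_i * x_i."""
    return intercept + sum(w * xi for w, xi in zip(coefs, x))


# Hypothetical linear heterogeneity model and background sample.
coefs, intercept = [2.0, -1.0, 0.5], 0.3
background = [
    [0.0, 1.0, 2.0],
    [1.0, 3.0, 0.0],
    [2.0, 2.0, 4.0],
    [1.0, 2.0, 2.0],
]
x = background[2]

phi = linear_shap(coefs, background, x)
average_effect = mean(predict(coefs, intercept, row) for row in background)

# Additivity: the contributions sum to the gap between this observation's
# effect and the average effect over the background sample.
print(phi, sum(phi), predict(coefs, intercept, x) - average_effect)
```

The per-feature contributions explain why this observation's effect differs from the average one, which is the same reading that applies to the force and summary plots above.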

doc/spec/references.rst

Lines changed: 6 additions & 1 deletion
@@ -108,4 +108,9 @@ References
108108
.. [Friedberg2018]
109109
Friedberg, R., Tibshirani, J., Athey, S., & Wager, S. (2018).
110110
Local linear forests.
111-
arXiv preprint arXiv:1807.11408.
111+
arXiv preprint arXiv:1807.11408.
112+
113+
.. [Lundberg2017]
114+
Lundberg, S., & Lee, S.-I. (2017).
115+
A Unified Approach to Interpreting Model Predictions.
116+
URL https://arxiv.org/abs/1705.07874

doc/spec/spec.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ The EconML Python SDK, developed by the ALICE team at MSR New England, incorpora
2020
estimation
2121
estimation_iv
2222
inference
23+
interpretability
2324
references
2425

2526
.. todo::

econml/_shap.py

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,17 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
"""Helper functions to get shap values for different cate estimators.
5+
6+
References
7+
----------
8+
Scott Lundberg, Su-In Lee (2017)
9+
A Unified Approach to Interpreting Model Predictions.
10+
NeurIPS, https://arxiv.org/abs/1705.07874
11+
12+
13+
"""
14+
415
import shap
516
from collections import defaultdict
617
import numpy as np

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ install_requires =
5050
graphviz
5151
matplotlib
5252
pandas < 1.1
53-
shap @ git+https://github.com/slundberg/shap.git
53+
shap ~= 0.38.1
5454
test_suite = econml.tests
5555
tests_require =
5656
pytest
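The new ``shap ~= 0.38.1`` requirement uses the compatible-release operator, which is equivalent to ``>= 0.38.1, == 0.38.*``. The sketch below spells that rule out in plain Python for purely numeric versions; it is a simplified illustration with a hypothetical helper name, not how pip evaluates specifiers (pre-releases, epochs, and local versions are ignored).

```python
def is_compatible_release(version, spec):
    """Check `version ~= spec` for plain dotted numeric versions only."""
    v = tuple(int(part) for part in version.split("."))
    s = tuple(int(part) for part in spec.split("."))
    # At least the pinned version, and identical in every segment of the
    # spec except the last one.
    return v >= s and v[:len(s) - 1] == s[:-1]


for candidate in ["0.38.0", "0.38.1", "0.38.5", "0.39.0"]:
    print(candidate, is_compatible_release(candidate, "0.38.1"))
```

Under this rule, later 0.38.x patch releases satisfy the requirement while 0.39.0 does not.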
