Skip to content

Commit 35c5418

Browse files
authored
Average Treatment Effect methods to all estimators (#365)
* added ate inference methods
1 parent f0b0e5b commit 35c5418

File tree

9 files changed

+2427
-1904
lines changed

9 files changed

+2427
-1904
lines changed

econml/_cate_estimator.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,54 @@ def marginal_effect(self, T, X=None):
181181
"""
182182
pass
183183

184+
def ate(self, X=None, *, T0, T1):
185+
"""
186+
Calculate the average treatment effect :math:`E_X[\\tau(X, T0, T1)]`.
187+
188+
The effect is calculated between the two treatment points and is averaged over
189+
the population of X variables.
190+
191+
Parameters
192+
----------
193+
T0: (m, d_t) matrix or vector of length m
194+
Base treatments for each sample
195+
T1: (m, d_t) matrix or vector of length m
196+
Target treatments for each sample
197+
X: optional (m, d_x) matrix
198+
Features for each sample
199+
200+
Returns
201+
-------
202+
τ: float or (d_y,) array
203+
Average treatment effects on each outcome
204+
Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar
205+
"""
206+
return np.mean(self.effect(X=X, T0=T0, T1=T1), axis=0)
207+
208+
def marginal_ate(self, T, X=None):
209+
"""
210+
Calculate the average marginal effect :math:`E_{T, X}[\\partial\\tau(T, X)]`.
211+
212+
The marginal effect is calculated around a base treatment
213+
point and averaged over the population of X.
214+
215+
Parameters
216+
----------
217+
T: (m, d_t) matrix
218+
Base treatments for each sample
219+
X: optional (m, d_x) matrix
220+
Features for each sample
221+
222+
Returns
223+
-------
224+
grad_tau: (d_y, d_t) array
225+
Average marginal effects on each outcome
226+
Note that when Y or T is a vector rather than a 2-dimensional array,
227+
the corresponding singleton dimensions in the output will be collapsed
228+
(e.g. if both are vectors, then the output of this method will be a scalar)
229+
"""
230+
return np.mean(self.marginal_effect(T, X=X), axis=0)
231+
184232
def _expand_treatments(self, X=None, *Ts):
185233
"""
186234
Given a set of features and treatments, return possibly modified features and treatments.
@@ -303,6 +351,101 @@ def marginal_effect_inference(self, T, X=None):
303351
"""
304352
pass
305353

354+
@_defer_to_inference
355+
def ate_interval(self, X=None, *, T0, T1, alpha=0.1):
356+
""" Confidence intervals for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
357+
by the model. Available only when ``inference`` is not ``None``, when
358+
calling the fit method.
359+
360+
Parameters
361+
----------
362+
X: optional (m, d_x) matrix
363+
Features for each sample
364+
T0: optional (m, d_t) matrix or vector of length m (Default=0)
365+
Base treatments for each sample
366+
T1: optional (m, d_t) matrix or vector of length m (Default=1)
367+
Target treatments for each sample
368+
alpha: optional float in [0, 1] (Default=0.1)
369+
The overall level of confidence of the reported interval.
370+
The alpha/2, 1-alpha/2 confidence interval is reported.
371+
372+
Returns
373+
-------
374+
lower, upper : tuple(type of :meth:`ate(X, T0, T1)<ate>`, type of :meth:`ate(X, T0, T1)<ate>` )
375+
The lower and the upper bounds of the confidence interval for each quantity.
376+
"""
377+
pass
378+
379+
@_defer_to_inference
380+
def ate_inference(self, X=None, *, T0, T1):
381+
""" Inference results for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
382+
by the model. Available only when ``inference`` is not ``None``, when
383+
calling the fit method.
384+
385+
Parameters
386+
----------
387+
X: optional (m, d_x) matrix
388+
Features for each sample
389+
T0: optional (m, d_t) matrix or vector of length m (Default=0)
390+
Base treatments for each sample
391+
T1: optional (m, d_t) matrix or vector of length m (Default=1)
392+
Target treatments for each sample
393+
394+
Returns
395+
-------
396+
PopulationSummaryResults: object
397+
The inference results instance contains prediction and prediction standard error and
398+
can on demand calculate confidence interval, z statistic and p value. It can also output
399+
a dataframe summary of these inference results.
400+
"""
401+
pass
402+
403+
@_defer_to_inference
404+
def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
405+
""" Confidence intervals for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
406+
by the model. Available only when ``inference`` is not ``None``, when
407+
calling the fit method.
408+
409+
Parameters
410+
----------
411+
T: (m, d_t) matrix
412+
Base treatments for each sample
413+
X: optional (m, d_x) matrix or None (Default=None)
414+
Features for each sample
415+
alpha: optional float in [0, 1] (Default=0.1)
416+
The overall level of confidence of the reported interval.
417+
The alpha/2, 1-alpha/2 confidence interval is reported.
418+
419+
Returns
420+
-------
421+
lower, upper : tuple(type of :meth:`marginal_ate(T, X)<marginal_ate>`, \
422+
type of :meth:`marginal_ate(T, X)<marginal_ate>` )
423+
The lower and the upper bounds of the confidence interval for each quantity.
424+
"""
425+
pass
426+
427+
@_defer_to_inference
428+
def marginal_ate_inference(self, T, X=None):
429+
""" Inference results for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
430+
by the model. Available only when ``inference`` is not ``None``, when
431+
calling the fit method.
432+
433+
Parameters
434+
----------
435+
T: (m, d_t) matrix
436+
Base treatments for each sample
437+
X: optional (m, d_x) matrix or None (Default=None)
438+
Features for each sample
439+
440+
Returns
441+
-------
442+
PopulationSummaryResults: object
443+
The inference results instance contains prediction and prediction standard error and
444+
can on demand calculate confidence interval, z statistic and p value. It can also output
445+
a dataframe summary of these inference results.
446+
"""
447+
pass
448+
306449

307450
class LinearCateEstimator(BaseCateEstimator):
308451
"""Base class for all CATE estimators with linear treatment effects in this package."""
@@ -457,6 +600,79 @@ def const_marginal_effect_inference(self, X=None):
457600
"""
458601
pass
459602

603+
def const_marginal_ate(self, X=None):
604+
"""
605+
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.
606+
607+
Parameters
608+
----------
609+
X: optional (m, d_x) matrix or None (Default=None)
610+
Features for each sample.
611+
612+
Returns
613+
-------
614+
theta: (d_y, d_t) matrix
615+
Average constant marginal CATE of each treatment on each outcome.
616+
Note that when Y or T is a vector rather than a 2-dimensional array,
617+
the corresponding singleton dimensions in the output will be collapsed
618+
(e.g. if both are vectors, then the output of this method will be a scalar)
619+
"""
620+
return np.mean(self.const_marginal_effect(X=X), axis=0)
621+
622+
@BaseCateEstimator._defer_to_inference
623+
def const_marginal_ate_interval(self, X=None, *, alpha=0.1):
624+
""" Confidence intervals for the quantities :math:`E_X[\\theta(X)]` produced
625+
by the model. Available only when ``inference`` is not ``None``, when
626+
calling the fit method.
627+
628+
Parameters
629+
----------
630+
X: optional (m, d_x) matrix or None (Default=None)
631+
Features for each sample
632+
alpha: optional float in [0, 1] (Default=0.1)
633+
The overall level of confidence of the reported interval.
634+
The alpha/2, 1-alpha/2 confidence interval is reported.
635+
636+
Returns
637+
-------
638+
lower, upper : tuple(type of :meth:`const_marginal_ate(X)<const_marginal_ate>` ,\
639+
type of :meth:`const_marginal_ate(X)<const_marginal_ate>` )
640+
The lower and the upper bounds of the confidence interval for each quantity.
641+
"""
642+
pass
643+
644+
@BaseCateEstimator._defer_to_inference
645+
def const_marginal_ate_inference(self, X=None):
646+
""" Inference results for the quantities :math:`E_X[\\theta(X)]` produced
647+
by the model. Available only when ``inference`` is not ``None``, when
648+
calling the fit method.
649+
650+
Parameters
651+
----------
652+
X: optional (m, d_x) matrix or None (Default=None)
653+
Features for each sample
654+
655+
Returns
656+
-------
657+
PopulationSummaryResults: object
658+
The inference results instance contains prediction and prediction standard error and
659+
can on demand calculate confidence interval, z statistic and p value. It can also output
660+
a dataframe summary of these inference results.
661+
"""
662+
pass
663+
664+
def marginal_ate(self, T, X=None):
665+
return self.const_marginal_ate(X=X)
666+
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__
667+
668+
def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
669+
return self.const_marginal_ate_interval(X=X, alpha=alpha)
670+
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__
671+
672+
def marginal_ate_inference(self, T, X=None):
673+
return self.const_marginal_ate_inference(X=X)
674+
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__
675+
460676
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
461677
""" Shap value for the final stage models (const_marginal_effect)
462678
@@ -524,6 +740,18 @@ def effect(self, X=None, *, T0=0, T1=1):
524740
return super().effect(X, T0=T0, T1=T1)
525741
effect.__doc__ = BaseCateEstimator.effect.__doc__
526742

743+
def ate(self, X=None, *, T0=0, T1=1):
744+
return super().ate(X=X, T0=T0, T1=T1)
745+
ate.__doc__ = BaseCateEstimator.ate.__doc__
746+
747+
def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
748+
return super().ate_interval(X=X, T0=T0, T1=T1, alpha=alpha)
749+
ate_interval.__doc__ = BaseCateEstimator.ate_interval.__doc__
750+
751+
def ate_inference(self, X=None, *, T0=0, T1=1):
752+
return super().ate_inference(X=X, T0=T0, T1=T1)
753+
ate_inference.__doc__ = BaseCateEstimator.ate_inference.__doc__
754+
527755

528756
class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
529757
"""
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
from ._interpreters import SingleTreeCateInterpreter, SingleTreePolicyInterpreter
5+
6+
__all__ = ["SingleTreeCateInterpreter",
7+
"SingleTreePolicyInterpreter"]

econml/cate_interpreter.py renamed to econml/cate_interpreter/_interpreters.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
class _SingleTreeInterpreter(metaclass=abc.ABCMeta):
1414

1515
tree_model = None
16+
node_dict = None
1617

1718
@abc.abstractmethod
1819
def interpret(self, cate_estimator, X):
@@ -156,7 +157,7 @@ def export_graphviz(self, out_file=None, feature_names=None,
156157
exporter = self._make_dot_exporter(out_file=out_file, feature_names=feature_names, filled=filled,
157158
leaves_parallel=leaves_parallel, rotate=rotate, rounded=rounded,
158159
special_characters=special_characters, precision=precision)
159-
exporter.export(self.tree_model)
160+
exporter.export(self.tree_model, node_dict=self.node_dict)
160161

161162
if return_string:
162163
return out_file.getvalue()
@@ -249,7 +250,7 @@ def plot(self, ax=None, title=None, feature_names=None,
249250
check_is_fitted(self.tree_model, 'tree_')
250251
exporter = self._make_mpl_exporter(title=title, feature_names=feature_names, filled=filled,
251252
rounded=rounded, precision=precision, fontsize=fontsize)
252-
exporter.export(self.tree_model, ax=ax)
253+
exporter.export(self.tree_model, node_dict=self.node_dict, ax=ax)
253254

254255

255256
class SingleTreeCateInterpreter(_SingleTreeInterpreter):
@@ -261,7 +262,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
261262
include_model_uncertainty : bool, optional, default False
262263
Whether to include confidence interval information when building a
263264
simplified model of the cate model. If set to True, then
264-
cate estimator needs to support the `effect_interval` method.
265+
cate estimator needs to support the `const_marginal_ate_inference` method.
265266
266267
uncertainty_level : double, optional, default .05
267268
The uncertainty level for the confidence intervals to be constructed
@@ -270,6 +271,11 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
270271
in a leaf have similar target prediction but also similar alpha
271272
confidence intervals.
272273
274+
uncertainty_only_on_leaves : bool, optional, default True
275+
Whether uncertainty information should be displayed only on leaf nodes.
276+
If False, then interpretation can be slightly slower, especially for cate
277+
models that have a computationally expensive inference method.
278+
273279
splitter : string, optional, default "best"
274280
The strategy used to choose the split at each node. Supported
275281
strategies are "best" to choose the best split and "random" to choose
@@ -335,6 +341,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
335341
def __init__(self,
336342
include_model_uncertainty=False,
337343
uncertainty_level=.1,
344+
uncertainty_only_on_leaves=True,
338345
splitter="best",
339346
max_depth=None,
340347
min_samples_split=2,
@@ -346,6 +353,7 @@ def __init__(self,
346353
min_impurity_decrease=0.):
347354
self.include_uncertainty = include_model_uncertainty
348355
self.uncertainty_level = uncertainty_level
356+
self.uncertainty_only_on_leaves = uncertainty_only_on_leaves
349357
self.criterion = "mse"
350358
self.splitter = splitter
351359
self.max_depth = max_depth
@@ -370,20 +378,23 @@ def interpret(self, cate_estimator, X):
370378
min_impurity_decrease=self.min_impurity_decrease)
371379
y_pred = cate_estimator.const_marginal_effect(X)
372380

373-
assert all(d == 1 for d in y_pred.shape[1:]), ("Interpretation is only available for "
374-
"single-dimensional treatments and outcomes")
375-
376-
if y_pred.ndim != 2:
377-
y_pred = y_pred.reshape(-1, 1)
378-
379-
if self.include_uncertainty:
380-
y_lower, y_upper = cate_estimator.const_marginal_effect_interval(X, alpha=self.uncertainty_level)
381-
if y_lower.ndim != 2:
382-
y_lower = y_lower.reshape(-1, 1)
383-
y_upper = y_upper.reshape(-1, 1)
384-
y_pred = np.hstack([y_pred, y_lower, y_upper])
385-
self.tree_model.fit(X, y_pred)
386-
381+
self.tree_model.fit(X, y_pred.reshape((y_pred.shape[0], -1)))
382+
paths = self.tree_model.decision_path(X)
383+
node_dict = {}
384+
for node_id in range(paths.shape[1]):
385+
mask = paths.getcol(node_id).toarray().flatten().astype(bool)
386+
Xsub = X[mask]
387+
if (self.include_uncertainty and
388+
((not self.uncertainty_only_on_leaves) or (self.tree_model.tree_.children_left[node_id] < 0))):
389+
res = cate_estimator.const_marginal_ate_inference(Xsub)
390+
node_dict[node_id] = {'mean': res.mean_point,
391+
'std': res.std_point,
392+
'ci': res.conf_int_mean(alpha=self.uncertainty_level)}
393+
else:
394+
cate_node = y_pred[mask]
395+
node_dict[node_id] = {'mean': np.mean(cate_node, axis=0),
396+
'std': np.std(cate_node, axis=0)}
397+
self.node_dict = node_dict
387398
return self
388399

389400
def _make_dot_exporter(self, *, out_file, feature_names, filled,

0 commit comments

Comments
 (0)