@@ -181,6 +181,54 @@ def marginal_effect(self, T, X=None):
181181 """
182182 pass
183183
184+ def ate (self , X = None , * , T0 , T1 ):
185+ """
186+ Calculate the average treatment effect :math:`E_X[\\ tau(X, T0, T1)]`.
187+
188+ The effect is calculated between the two treatment points and is averaged over
189+ the population of X variables.
190+
191+ Parameters
192+ ----------
193+ T0: (m, d_t) matrix or vector of length m
194+ Base treatments for each sample
195+ T1: (m, d_t) matrix or vector of length m
196+ Target treatments for each sample
197+ X: optional (m, d_x) matrix
198+ Features for each sample
199+
200+ Returns
201+ -------
202+ τ: float or (d_y,) array
203+ Average treatment effects on each outcome
204+ Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar
205+ """
206+ return np .mean (self .effect (X = X , T0 = T0 , T1 = T1 ), axis = 0 )
207+
208+ def marginal_ate (self , T , X = None ):
209+ """
210+ Calculate the average marginal effect :math:`E_{T, X}[\\ partial\\ tau(T, X)]`.
211+
212+ The marginal effect is calculated around a base treatment
213+ point and averaged over the population of X.
214+
215+ Parameters
216+ ----------
217+ T: (m, d_t) matrix
218+ Base treatments for each sample
219+ X: optional (m, d_x) matrix
220+ Features for each sample
221+
222+ Returns
223+ -------
224+ grad_tau: (d_y, d_t) array
225+ Average marginal effects on each outcome
226+ Note that when Y or T is a vector rather than a 2-dimensional array,
227+ the corresponding singleton dimensions in the output will be collapsed
228+ (e.g. if both are vectors, then the output of this method will be a scalar)
229+ """
230+ return np .mean (self .marginal_effect (T , X = X ), axis = 0 )
231+
184232 def _expand_treatments (self , X = None , * Ts ):
185233 """
186234 Given a set of features and treatments, return possibly modified features and treatments.
@@ -303,6 +351,101 @@ def marginal_effect_inference(self, T, X=None):
303351 """
304352 pass
305353
354+ @_defer_to_inference
355+ def ate_interval (self , X = None , * , T0 , T1 , alpha = 0.1 ):
356+ """ Confidence intervals for the quantity :math:`E_X[\\ tau(X, T0, T1)]` produced
357+ by the model. Available only when ``inference`` is not ``None``, when
358+ calling the fit method.
359+
360+ Parameters
361+ ----------
362+ X: optional (m, d_x) matrix
363+ Features for each sample
364+ T0: optional (m, d_t) matrix or vector of length m (Default=0)
365+ Base treatments for each sample
366+ T1: optional (m, d_t) matrix or vector of length m (Default=1)
367+ Target treatments for each sample
368+ alpha: optional float in [0, 1] (Default=0.1)
369+ The overall level of confidence of the reported interval.
370+ The alpha/2, 1-alpha/2 confidence interval is reported.
371+
372+ Returns
373+ -------
374+ lower, upper : tuple(type of :meth:`ate(X, T0, T1)<ate>`, type of :meth:`ate(X, T0, T1))<ate>` )
375+ The lower and the upper bounds of the confidence interval for each quantity.
376+ """
377+ pass
378+
379+ @_defer_to_inference
380+ def ate_inference (self , X = None , * , T0 , T1 ):
381+ """ Inference results for the quantity :math:`E_X[\\ tau(X, T0, T1)]` produced
382+ by the model. Available only when ``inference`` is not ``None``, when
383+ calling the fit method.
384+
385+ Parameters
386+ ----------
387+ X: optional (m, d_x) matrix
388+ Features for each sample
389+ T0: optional (m, d_t) matrix or vector of length m (Default=0)
390+ Base treatments for each sample
391+ T1: optional (m, d_t) matrix or vector of length m (Default=1)
392+ Target treatments for each sample
393+
394+ Returns
395+ -------
396+ PopulationSummaryResults: object
397+ The inference results instance contains prediction and prediction standard error and
398+ can on demand calculate confidence interval, z statistic and p value. It can also output
399+ a dataframe summary of these inference results.
400+ """
401+ pass
402+
403+ @_defer_to_inference
404+ def marginal_ate_interval (self , T , X = None , * , alpha = 0.1 ):
405+ """ Confidence intervals for the quantities :math:`E_{T,X}[\\ partial \\ tau(T, X)]` produced
406+ by the model. Available only when ``inference`` is not ``None``, when
407+ calling the fit method.
408+
409+ Parameters
410+ ----------
411+ T: (m, d_t) matrix
412+ Base treatments for each sample
413+ X: optional (m, d_x) matrix or None (Default=None)
414+ Features for each sample
415+ alpha: optional float in [0, 1] (Default=0.1)
416+ The overall level of confidence of the reported interval.
417+ The alpha/2, 1-alpha/2 confidence interval is reported.
418+
419+ Returns
420+ -------
421+ lower, upper : tuple(type of :meth:`marginal_ate(T, X)<marginal_ate>`, \
422+ type of :meth:`marginal_ate(T, X)<marginal_ate>` )
423+ The lower and the upper bounds of the confidence interval for each quantity.
424+ """
425+ pass
426+
427+ @_defer_to_inference
428+ def marginal_ate_inference (self , T , X = None ):
429+ """ Inference results for the quantities :math:`E_{T,X}[\\ partial \\ tau(T, X)]` produced
430+ by the model. Available only when ``inference`` is not ``None``, when
431+ calling the fit method.
432+
433+ Parameters
434+ ----------
435+ T: (m, d_t) matrix
436+ Base treatments for each sample
437+ X: optional (m, d_x) matrix or None (Default=None)
438+ Features for each sample
439+
440+ Returns
441+ -------
442+ PopulationSummaryResults: object
443+ The inference results instance contains prediction and prediction standard error and
444+ can on demand calculate confidence interval, z statistic and p value. It can also output
445+ a dataframe summary of these inference results.
446+ """
447+ pass
448+
306449
307450class LinearCateEstimator (BaseCateEstimator ):
308451 """Base class for all CATE estimators with linear treatment effects in this package."""
@@ -457,6 +600,79 @@ def const_marginal_effect_inference(self, X=None):
457600 """
458601 pass
459602
603+ def const_marginal_ate (self , X = None ):
604+ """
605+ Calculate the average constant marginal CATE :math:`E_X[\\ theta(X)]`.
606+
607+ Parameters
608+ ----------
609+ X: optional (m, d_x) matrix or None (Default=None)
610+ Features for each sample.
611+
612+ Returns
613+ -------
614+ theta: (d_y, d_t) matrix
615+ Average constant marginal CATE of each treatment on each outcome.
616+ Note that when Y or T is a vector rather than a 2-dimensional array,
617+ the corresponding singleton dimensions in the output will be collapsed
618+ (e.g. if both are vectors, then the output of this method will be a scalar)
619+ """
620+ return np .mean (self .const_marginal_effect (X = X ), axis = 0 )
621+
622+ @BaseCateEstimator ._defer_to_inference
623+ def const_marginal_ate_interval (self , X = None , * , alpha = 0.1 ):
624+ """ Confidence intervals for the quantities :math:`E_X[\\ theta(X)]` produced
625+ by the model. Available only when ``inference`` is not ``None``, when
626+ calling the fit method.
627+
628+ Parameters
629+ ----------
630+ X: optional (m, d_x) matrix or None (Default=None)
631+ Features for each sample
632+ alpha: optional float in [0, 1] (Default=0.1)
633+ The overall level of confidence of the reported interval.
634+ The alpha/2, 1-alpha/2 confidence interval is reported.
635+
636+ Returns
637+ -------
638+ lower, upper : tuple(type of :meth:`const_marginal_ate(X)<const_marginal_ate>` ,\
639+ type of :meth:`const_marginal_ate(X)<const_marginal_ate>` )
640+ The lower and the upper bounds of the confidence interval for each quantity.
641+ """
642+ pass
643+
644+ @BaseCateEstimator ._defer_to_inference
645+ def const_marginal_ate_inference (self , X = None ):
646+ """ Inference results for the quantities :math:`E_X[\\ theta(X)]` produced
647+ by the model. Available only when ``inference`` is not ``None``, when
648+ calling the fit method.
649+
650+ Parameters
651+ ----------
652+ X: optional (m, d_x) matrix or None (Default=None)
653+ Features for each sample
654+
655+ Returns
656+ -------
657+ PopulationSummaryResults: object
658+ The inference results instance contains prediction and prediction standard error and
659+ can on demand calculate confidence interval, z statistic and p value. It can also output
660+ a dataframe summary of these inference results.
661+ """
662+ pass
663+
664+ def marginal_ate (self , T , X = None ):
665+ return self .const_marginal_ate (X = X )
666+ marginal_ate .__doc__ = BaseCateEstimator .marginal_ate .__doc__
667+
668+ def marginal_ate_interval (self , T , X = None , * , alpha = 0.1 ):
669+ return self .const_marginal_ate_interval (X = X , alpha = alpha )
670+ marginal_ate_interval .__doc__ = BaseCateEstimator .marginal_ate_interval .__doc__
671+
672+ def marginal_ate_inference (self , T , X = None ):
673+ return self .const_marginal_ate_inference (X = X )
674+ marginal_ate_inference .__doc__ = BaseCateEstimator .marginal_ate_inference .__doc__
675+
460676 def shap_values (self , X , * , feature_names = None , treatment_names = None , output_names = None , background_samples = 100 ):
461677 """ Shap value for the final stage models (const_marginal_effect)
462678
@@ -524,6 +740,18 @@ def effect(self, X=None, *, T0=0, T1=1):
524740 return super ().effect (X , T0 = T0 , T1 = T1 )
525741 effect .__doc__ = BaseCateEstimator .effect .__doc__
526742
743+ def ate (self , X = None , * , T0 = 0 , T1 = 1 ):
744+ return super ().ate (X = X , T0 = T0 , T1 = T1 )
745+ ate .__doc__ = BaseCateEstimator .ate .__doc__
746+
747+ def ate_interval (self , X = None , * , T0 = 0 , T1 = 1 , alpha = 0.1 ):
748+ return super ().ate_interval (X = X , T0 = T0 , T1 = T1 , alpha = alpha )
749+ ate_interval .__doc__ = BaseCateEstimator .ate_interval .__doc__
750+
751+ def ate_inference (self , X = None , * , T0 = 0 , T1 = 1 ):
752+ return super ().ate_inference (X = X , T0 = T0 , T1 = T1 )
753+ ate_inference .__doc__ = BaseCateEstimator .ate_inference .__doc__
754+
527755
528756class LinearModelFinalCateEstimatorMixin (BaseCateEstimator ):
529757 """
0 commit comments