@@ -147,7 +147,7 @@ def const_marginal_effect_inference(self, X):
147147 warn ("Final model doesn't have a `prediction_stderr` method, "
148148 "only point estimates will be returned." )
149149 return NormalInferenceResults (d_t = self .d_t , d_y = self .d_y , pred = pred ,
150- pred_stderr = pred_stderr , inf_type = 'effect' ,
150+ pred_stderr = pred_stderr , mean_pred_stderr = None , inf_type = 'effect' ,
151151 feature_names = self ._est .cate_feature_names (),
152152 output_names = self ._est .cate_output_names (),
153153 treatment_names = self ._est .cate_treatment_names ())
@@ -193,9 +193,10 @@ def effect_inference(self, X, *, T0, T1):
193193 e_pred = np .einsum (einsum_str , cme_pred , dT )
194194 e_stderr = np .einsum (einsum_str , cme_stderr , np .abs (dT )) if cme_stderr is not None else None
195195 d_y = self ._d_y [0 ] if self ._d_y else 1
196+
196197 # d_t=None here since we measure the effect across all Ts
197198 return NormalInferenceResults (d_t = None , d_y = d_y , pred = e_pred ,
198- pred_stderr = e_stderr , inf_type = 'effect' ,
199+ pred_stderr = e_stderr , mean_pred_stderr = None , inf_type = 'effect' ,
199200 feature_names = self ._est .cate_feature_names (),
200201 output_names = self ._est .cate_output_names ())
201202
@@ -240,15 +241,38 @@ def effect_inference(self, X, *, T0, T1):
240241 X = np .ones ((T0 .shape [0 ], 1 ))
241242 elif self .featurizer is not None :
242243 X = self .featurizer .transform (X )
243- e_pred = self ._predict (cross_product (X , T1 - T0 ))
244- e_stderr = self ._prediction_stderr (cross_product (X , T1 - T0 ))
244+ XT = cross_product (X , T1 - T0 )
245+ e_pred = self ._predict (XT )
246+ e_stderr = self ._prediction_stderr (XT )
245247 d_y = self ._d_y [0 ] if self ._d_y else 1
248+
249+ mean_XT = XT .mean (axis = 0 , keepdims = True )
250+ mean_pred_stderr = self ._prediction_stderr (mean_XT ) # shape[0] will always be 1 here
251+ # squeeze the first axis
252+ mean_pred_stderr = np .squeeze (mean_pred_stderr , axis = 0 ) if mean_pred_stderr is not None else None
246253 # d_t=None here since we measure the effect across all Ts
247254 return NormalInferenceResults (d_t = None , d_y = d_y , pred = e_pred ,
248- pred_stderr = e_stderr , inf_type = 'effect' ,
255+ pred_stderr = e_stderr , mean_pred_stderr = mean_pred_stderr , inf_type = 'effect' ,
249256 feature_names = self ._est .cate_feature_names (),
250257 output_names = self ._est .cate_output_names ())
251258
def const_marginal_effect_inference(self, X):
    """Return the parent's inference results with ``mean_pred_stderr`` filled in.

    The parent implementation produces the per-row results. This override
    additionally evaluates ``self._prediction_stderr`` on the cross product of
    the column-wise mean of the (featurized) ``X`` with the unit treatments,
    and stores the outcome on the returned object.

    Parameters
    ----------
    X : array-like or None
        Features to evaluate at. ``None`` is replaced by a single constant
        column of ones; otherwise ``self.featurizer`` is applied when set.

    Returns
    -------
    The object returned by the parent method. Its ``mean_pred_stderr``
    attribute is overwritten only when ``self._prediction_stderr`` returns
    something other than ``None``.
    """
    results = super().const_marginal_effect_inference(X)

    # Build the feature matrix that is averaged below.
    if X is None:
        features = np.ones((1, 1))
    elif self.featurizer is None:
        features = X
    else:
        features = self.featurizer.transform(X)

    # One averaged row, paired with every unit treatment.
    avg_row = features.mean(axis=0).reshape(1, -1)
    avg_X, unit_T = broadcast_unit_treatments(avg_row, self.d_t)
    stderr_at_mean = self._prediction_stderr(cross_product(avg_X, unit_T))
    if stderr_at_mean is None:
        # No standard errors available: keep whatever the parent set.
        return results

    stderr_at_mean = reshape_treatmentwise_effects(stderr_at_mean, self._d_t, self._d_y)
    # Only one (averaged) row was evaluated, so the leading axis is expected
    # to have length 1 and is dropped.
    results.mean_pred_stderr = np.squeeze(stderr_at_mean, axis=0)
    return results
275+
252276 def coef__interval (self , * , alpha = 0.1 ):
253277 lo , hi = self .model_final .coef__interval (alpha )
254278 lo_int , hi_int = self .model_final .intercept__interval (alpha )
@@ -285,6 +309,7 @@ def coef__inference(self):
285309 fname_transformer = self ._est .cate_feature_names
286310
287311 return NormalInferenceResults (d_t = self .d_t , d_y = self .d_y , pred = coef , pred_stderr = coef_stderr ,
312+ mean_pred_stderr = None ,
288313 inf_type = 'coefficient' , fname_transformer = fname_transformer ,
289314 feature_names = self ._est .cate_feature_names (),
290315 output_names = self ._est .cate_output_names (),
@@ -323,6 +348,7 @@ def intercept__inference(self):
323348 intercept_stderr = None
324349
325350 return NormalInferenceResults (d_t = self .d_t , d_y = self .d_y , pred = intercept , pred_stderr = intercept_stderr ,
351+ mean_pred_stderr = None ,
326352 inf_type = 'intercept' ,
327353 feature_names = self ._est .cate_feature_names (),
328354 output_names = self ._est .cate_output_names (),
@@ -380,11 +406,7 @@ def fit(self, estimator, *args, **kwargs):
380406 self .fit_cate_intercept = estimator .fit_cate_intercept
381407
382408 def const_marginal_effect_interval (self , X , * , alpha = 0.1 ):
383- if (X is not None ) and (self .featurizer is not None ):
384- X = self .featurizer .transform (X )
385- preds = np .array ([tuple (map (lambda x : x .reshape ((- 1 ,) + self ._d_y ), mdl .predict_interval (X , alpha = alpha )))
386- for mdl in self .fitted_models_final ])
387- return tuple (np .moveaxis (preds , [0 , 1 ], [- 1 , 0 ])) # send treatment to the end, pull bounds to the front
409+ return self .const_marginal_effect_inference (X ).conf_int (alpha = alpha )
388410
389411 def const_marginal_effect_inference (self , X ):
390412 if (X is not None ) and (self .featurizer is not None ):
@@ -401,22 +423,14 @@ def const_marginal_effect_inference(self, X):
401423 "Only point estimates will be available." )
402424 pred_stderr = None
403425 return NormalInferenceResults (d_t = self .d_t , d_y = self .d_y , pred = pred ,
404- pred_stderr = pred_stderr , inf_type = 'effect' ,
426+ pred_stderr = pred_stderr , mean_pred_stderr = None ,
427+ inf_type = 'effect' ,
405428 feature_names = self ._est .cate_feature_names (),
406429 output_names = self ._est .cate_output_names (),
407430 treatment_names = self ._est .cate_treatment_names ())
408431
409432 def effect_interval (self , X , * , T0 , T1 , alpha = 0.1 ):
410- X , T0 , T1 = self ._est ._expand_treatments (X , T0 , T1 )
411- if np .any (np .any (T0 > 0 , axis = 1 )):
412- raise AttributeError ("Can only calculate intervals of effects with respect to baseline treatment!" )
413- ind = inverse_onehot (T1 )
414- lower , upper = self .const_marginal_effect_interval (X , alpha = alpha )
415- lower = np .concatenate ([np .zeros (lower .shape [0 :- 1 ] + (1 ,)), lower ], - 1 )
416- upper = np .concatenate ([np .zeros (upper .shape [0 :- 1 ] + (1 ,)), upper ], - 1 )
417- if X is None : # Then const_marginal_effect_interval will return a single row
418- lower , upper = np .repeat (lower , T0 .shape [0 ], axis = 0 ), np .repeat (upper , T0 .shape [0 ], axis = 0 )
419- return lower [np .arange (T0 .shape [0 ]), ..., ind ], upper [np .arange (T0 .shape [0 ]), ..., ind ]
433+ return self .effect_inference (X , T0 = T0 , T1 = T1 ).conf_int (alpha = alpha )
420434
421435 def effect_inference (self , X , * , T0 , T1 ):
422436 X , T0 , T1 = self ._est ._expand_treatments (X , T0 , T1 )
@@ -434,9 +448,10 @@ def effect_inference(self, X, *, T0, T1):
434448 pred_stderr = np .repeat (pred_stderr , T0 .shape [0 ], axis = 0 ) if pred_stderr is not None else None
435449 pred = pred [np .arange (T0 .shape [0 ]), ..., ind ]
436450 pred_stderr = pred_stderr [np .arange (T0 .shape [0 ]), ..., ind ] if pred_stderr is not None else None
451+
437452 # d_t=None here since we measure the effect across all Ts
438453 return NormalInferenceResults (d_t = None , d_y = self .d_y , pred = pred ,
439- pred_stderr = pred_stderr ,
454+ pred_stderr = pred_stderr , mean_pred_stderr = None ,
440455 inf_type = 'effect' ,
441456 feature_names = self ._est .cate_feature_names (),
442457 output_names = self ._est .cate_output_names ())
@@ -449,6 +464,33 @@ class LinearModelFinalInferenceDiscrete(GenericModelFinalInferenceDiscrete):
449464 based on the corresponding methods of the underlying model_final estimator.
450465 """
451466
def const_marginal_effect_inference(self, X):
    """Return the parent's inference results with ``mean_pred_stderr`` filled in.

    When the first fitted final model exposes ``prediction_stderr``, every
    model in ``self.fitted_models_final`` is evaluated at the column-wise mean
    of the (featurized) ``X``. The per-model results are stacked along a new
    trailing axis and stored on the returned object.

    Parameters
    ----------
    X : array-like or None
        Features to evaluate at. ``self.featurizer`` is applied when both it
        and ``X`` are not ``None``; a ``None`` ``X`` is forwarded to the
        models unchanged.

    Returns
    -------
    The object returned by the parent method. Its ``mean_pred_stderr``
    attribute is overwritten only when the models provide
    ``prediction_stderr``.
    """
    results = super().const_marginal_effect_inference(X)

    if X is not None and self.featurizer is not None:
        X = self.featurizer.transform(X)

    models = self.fitted_models_final
    if not hasattr(models[0], 'prediction_stderr'):
        return results

    avg_X = None if X is None else X.mean(axis=0).reshape(1, -1)
    per_model = [mdl.prediction_stderr(avg_X).reshape((-1,) + self._d_y) for mdl in models]
    # The list index (one entry per fitted model) becomes the last axis.
    stacked = np.moveaxis(np.array(per_model), 0, -1)
    # A single averaged row is evaluated, so the leading axis is expected to
    # have length 1 and is dropped.
    results.mean_pred_stderr = np.squeeze(stacked, axis=0)
    return results
481+
def effect_inference(self, X, *, T0, T1):
    """Return the parent's effect inference, adding ``mean_pred_stderr`` when possible.

    If every row of the expanded ``T1`` selects the same non-zero treatment
    index, ``mean_pred_stderr`` is taken from the matching treatment slice of
    ``self.const_marginal_effect_inference(X).mean_pred_stderr``. In every
    other case the value set by the parent method is left untouched.

    Parameters
    ----------
    X : array-like or None
        Features, forwarded to the parent method and to
        ``const_marginal_effect_inference``.
    T0, T1 : array-like
        Base and target treatments, forwarded to the parent method and to
        ``self._est._expand_treatments``.

    Returns
    -------
    The inference results object produced by the parent method.
    """
    res_inf = super().effect_inference(X, T0=T0, T1=T1)

    _, _, T1 = self._est._expand_treatments(X, T0, T1)
    ind = inverse_onehot(T1)

    # Only a single, constant target treatment has a well-defined slice.
    # Index 0 has no column in the treatment-wise results (columns are
    # addressed as ``ind - 1``), so ``ind[0] - 1`` would wrap around to -1 and
    # silently pick the last treatment's standard error; skip that case.
    # NOTE(review): index 0 is presumably the baseline treatment - confirm.
    if len(set(ind)) == 1 and ind[0] > 0:
        cme_mean_stderr = self.const_marginal_effect_inference(X).mean_pred_stderr
        # Stays None when the final models expose no ``prediction_stderr``;
        # there is nothing to slice in that case.
        if cme_mean_stderr is not None:
            res_inf.mean_pred_stderr = cme_mean_stderr[..., ind[0] - 1]
    return res_inf
493+
452494 def coef__interval (self , T , * , alpha = 0.1 ):
453495 _ , T = self ._est ._expand_treatments (None , T )
454496 ind = inverse_onehot (T ).item () - 1
@@ -472,8 +514,10 @@ def coef__inference(self, T):
472514 fname_transformer = None
473515 if hasattr (self ._est , 'cate_feature_names' ) and callable (self ._est .cate_feature_names ):
474516 fname_transformer = self ._est .cate_feature_names
517+
475518 # d_t=None here since we measure the effect across all Ts
476519 return NormalInferenceResults (d_t = None , d_y = self .d_y , pred = coef , pred_stderr = coef_stderr ,
520+ mean_pred_stderr = None ,
477521 inf_type = 'coefficient' , fname_transformer = fname_transformer ,
478522 feature_names = self ._est .cate_feature_names (),
479523 output_names = self ._est .cate_output_names ())
@@ -500,7 +544,7 @@ def intercept__inference(self, T):
500544 intercept_stderr = None
501545 # d_t=None here since we measure the effect across all Ts
502546 return NormalInferenceResults (d_t = None , d_y = self .d_y , pred = self .fitted_models_final [ind ].intercept_ ,
503- pred_stderr = intercept_stderr ,
547+ pred_stderr = intercept_stderr , mean_pred_stderr = None ,
504548 inf_type = 'intercept' ,
505549 feature_names = self ._est .cate_feature_names (),
506550 output_names = self ._est .cate_output_names ())
@@ -748,7 +792,6 @@ def summary_frame(self, alpha=0.1, value=0, decimals=3,
748792
749793 elif self .inf_type == 'intercept' :
750794 res .index = res .index .set_levels (['cate_intercept' ], level = "X" )
751-
752795 if self ._d_t == 1 :
753796 res .index = res .index .droplevel ("T" )
754797 if self .d_y == 1 :
@@ -786,6 +829,7 @@ def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_n
786829 output_names = self .output_names if output_names is None else output_names
787830 if self .inf_type == 'effect' :
788831 return PopulationSummaryResults (pred = self .point_estimate , pred_stderr = self .stderr ,
832+ mean_pred_stderr = None ,
789833 d_t = self .d_t , d_y = self .d_y ,
790834 alpha = alpha , value = value , decimals = decimals , tol = tol ,
791835 output_names = output_names , treatment_names = treatment_names )
@@ -839,17 +883,22 @@ class NormalInferenceResults(InferenceResults):
839883 Note that when Y or T is a vector rather than a 2-dimensional array,
840884 the corresponding singleton dimensions should be collapsed
841885 (e.g. if both are vectors, then the input of this argument will also be a vector)
886+ mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
887+ The standard error of the mean point estimate. It is derived from the coefficient stderr when the
888+ final stage is a linear model; otherwise it is None.
889+ This is the exact standard error of the mean, which is not conservative.
842890 inf_type: string
843891 The type of inference result.
844892 It could be either 'effect', 'coefficient' or 'intercept'.
845893 fname_transformer: None or predefined function
846894 The transform function to get the corresponding feature names from featurizer
847895 """
848896
849- def __init__ (self , d_t , d_y , pred , pred_stderr , inf_type , fname_transformer = None ,
897+ def __init__ (self , d_t , d_y , pred , pred_stderr , mean_pred_stderr , inf_type , fname_transformer = None ,
850898 feature_names = None , output_names = None , treatment_names = None ):
851899 self .pred_stderr = np .copy (pred_stderr ) if pred_stderr is not None and not np .isscalar (
852900 pred_stderr ) else pred_stderr
901+ self .mean_pred_stderr = mean_pred_stderr
853902 super ().__init__ (d_t , d_y , pred , inf_type , fname_transformer , feature_names , output_names , treatment_names )
854903
855904 @property
@@ -915,11 +964,20 @@ def pvalue(self, value=0):
915964 """
916965 return norm .sf (np .abs (self .zstat (value )), loc = 0 , scale = 1 ) * 2
917966
def population_summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
    # Every argument is passed through to the parent implementation by keyword.
    forwarded = dict(alpha=alpha, value=value, decimals=decimals, tol=tol,
                     output_names=output_names, treatment_names=treatment_names)
    summary = super().population_summary(**forwarded)
    # Copy this object's mean_pred_stderr onto the summary built by the parent.
    summary.mean_pred_stderr = self.mean_pred_stderr
    return summary
# Reuse the parent's docstring, since the override has the same signature.
population_summary.__doc__ = InferenceResults.population_summary.__doc__
973+
918974 def _expand_outputs (self , n_rows ):
919975 assert shape (self .pred )[0 ] == shape (self .pred_stderr )[0 ] == 1
920976 pred = np .repeat (self .pred , n_rows , axis = 0 )
921977 pred_stderr = np .repeat (self .pred_stderr , n_rows , axis = 0 ) if self .pred_stderr is not None else None
922- return NormalInferenceResults (self .d_t , self .d_y , pred , pred_stderr , self .inf_type ,
978+ return NormalInferenceResults (self .d_t , self .d_y , pred , pred_stderr ,
979+ self .mean_pred_stderr ,
980+ self .inf_type ,
923981 self .fname_transformer , self .feature_names ,
924982 self .output_names , self .treatment_names )
925983
@@ -1039,6 +1097,10 @@ class PopulationSummaryResults:
10391097 Note that when Y or T is a vector rather than a 2-dimensional array,
10401098 the corresponding singleton dimensions should be collapsed
10411099 (e.g. if both are vectors, then the input of this argument will also be a vector)
1100+ mean_pred_stderr: None or array-like or scalar, shape (d_y, d_t) or (d_y,)
1101+ The standard error of the mean point estimate. It is derived from the coefficient stderr when the
1102+ final stage is a linear model; otherwise it is None.
1103+ This is the exact standard error of the mean, which is not conservative.
10421104 alpha: optional float in [0, 1] (default=0.1)
10431105 The overall level of confidence of the reported interval.
10441106 The alpha/2, 1-alpha/2 confidence interval is reported.
@@ -1055,10 +1117,11 @@ class PopulationSummaryResults:
10551117
10561118 """
10571119
1058- def __init__ (self , pred , pred_stderr , d_t , d_y , alpha , value , decimals , tol ,
1120+ def __init__ (self , pred , pred_stderr , mean_pred_stderr , d_t , d_y , alpha , value , decimals , tol ,
10591121 output_names = None , treatment_names = None ):
10601122 self .pred = pred
10611123 self .pred_stderr = pred_stderr
1124+ self .mean_pred_stderr = mean_pred_stderr
10621125 self .d_t = d_t
10631126 # For effect summaries, d_t is None, but the result arrays behave as if d_t=1
10641127 self ._d_t = d_t or 1
@@ -1106,7 +1169,9 @@ def stderr_mean(self):
11061169 the corresponding singleton dimensions in the output will be collapsed
11071170 (e.g. if both are vectors, then the output of this method will be a scalar)
11081171 """
1109- if self .pred_stderr is None :
1172+ if self .mean_pred_stderr is not None :
1173+ return self .mean_pred_stderr
1174+ elif self .pred_stderr is None :
11101175 raise AttributeError ("Only point estimates are available!" )
11111176 return np .sqrt (np .mean (self .pred_stderr ** 2 , axis = 0 ))
11121177
@@ -1312,13 +1377,13 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
13121377 self ._format_res (self .pvalue (value = value ), decimals ),
13131378 self ._format_res (self .conf_int_mean (alpha = alpha )[0 ], decimals ),
13141379 self ._format_res (self .conf_int_mean (alpha = alpha )[1 ], decimals )))
1315-
13161380 if treatment_names is None :
13171381 treatment_names = ['T' + str (i ) for i in range (self ._d_t )]
13181382 if output_names is None :
13191383 output_names = ['Y' + str (i ) for i in range (self .d_y )]
13201384
13211385 myheaders1 = ['mean_point' , 'stderr_mean' , 'zstat' , 'pvalue' , 'ci_mean_lower' , 'ci_mean_upper' ]
1386+
13221387 mystubs = self ._get_stub_names (self .d_y , self ._d_t , treatment_names , output_names )
13231388 title1 = "Uncertainty of Mean Point Estimate"
13241389
@@ -1331,13 +1396,12 @@ def _print(self, *, alpha=None, value=None, decimals=None, tol=None, output_name
13311396
13321397 smry = Summary ()
13331398 smry .add_table (res1 , myheaders1 , mystubs , title1 )
1334- if self .pred_stderr is not None :
1399+ if self .pred_stderr is not None and self . mean_pred_stderr is None :
13351400 text1 = "Note: The stderr_mean is a conservative upper bound."
13361401 smry .add_extra_txt ([text1 ])
13371402 smry .add_table (res2 , myheaders2 , mystubs , title2 )
13381403
13391404 if self .pred_stderr is not None :
1340-
13411405 # 3. Total Variance of Point Estimate
13421406 res3 = np .hstack ((self ._format_res (self .stderr_point , self .decimals ),
13431407 self ._format_res (self .conf_int_point (alpha = alpha , tol = tol )[0 ],
0 commit comments