Skip to content

Commit 9a96875

Browse files
authored
Refitting model_final and nuisance averaging (#360)
* Support refitting only final model in DML after changing estimator parameters * Add support for monte carlo nuisance estimation, with multiple k-fold draws. * added rlearner residuals_ property that returns fitted residuals for training data (fixes #350) * fixed flaky cate interpreter test * added refit example in the dml notebook
1 parent ce2f2b5 commit 9a96875

32 files changed

+2484
-744
lines changed

econml/_cate_estimator.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,11 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
529529
"""
530530
Base class for models where the final stage is a linear model.
531531
532-
Subclasses must expose a ``model_final`` attribute containing the model's
533-
final stage model.
532+
Such an estimator must implement a :attr:`model_final_` attribute that points
533+
to the fitted final :class:`.StatsModelsLinearRegression` object that
534+
represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
535+
to the fitted featurizer and :attr:`bias_part_of_coef` that designates
536+
if the intercept is the first element of the :attr:`model_final_` coefficient.
534537
535538
Attributes
536539
----------
@@ -544,7 +547,9 @@ def _get_inference_options(self):
544547
options.update(auto=LinearModelFinalInference)
545548
return options
546549

547-
bias_part_of_coef = False
550+
@property
551+
def bias_part_of_coef(self):
552+
return False
548553

549554
@property
550555
def coef_(self):
@@ -561,9 +566,9 @@ def coef_(self):
561566
a vector and not a 2D array. For binary treatment the n_t dimension is
562567
also omitted.
563568
"""
564-
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
569+
return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_,
565570
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
566-
self.fit_cate_intercept)[0]
571+
self.fit_cate_intercept_)[0]
567572

568573
@property
569574
def intercept_(self):
@@ -578,11 +583,11 @@ def intercept_(self):
578583
a vector and not a 2D array. For binary treatment the n_t dimension is
579584
also omitted.
580585
"""
581-
if not self.fit_cate_intercept:
586+
if not self.fit_cate_intercept_:
582587
raise AttributeError("No intercept was fitted!")
583-
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
588+
return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_,
584589
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
585-
self.fit_cate_intercept)[1]
590+
self.fit_cate_intercept_)[1]
586591

587592
@BaseCateEstimator._defer_to_inference
588593
def coef__interval(self, *, alpha=0.1):
@@ -718,11 +723,11 @@ def summary(self, alpha=0.1, value=0, decimals=3, feature_names=None, treatment_
718723
return smry
719724

720725
def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
721-
if hasattr(self, "featurizer") and self.featurizer is not None:
722-
X = self.featurizer.transform(X)
726+
if hasattr(self, "featurizer_") and self.featurizer_ is not None:
727+
X = self.featurizer_.transform(X)
723728
feature_names = self.cate_feature_names(feature_names)
724-
return _shap_explain_joint_linear_model_cate(self.model_final, X, self._d_t, self._d_y,
725-
self.fit_cate_intercept,
729+
return _shap_explain_joint_linear_model_cate(self.model_final_, X, self._d_t, self._d_y,
730+
self.bias_part_of_coef,
726731
feature_names=feature_names, treatment_names=treatment_names,
727732
output_names=output_names,
728733
input_names=self._input_names,
@@ -736,9 +741,11 @@ class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
736741
Mixin class that offers `inference='statsmodels'` options to the CATE estimator
737742
that inherits it.
738743
739-
Such an estimator must implement a :attr:`model_final` attribute that points
744+
Such an estimator must implement a :attr:`model_final_` attribute that points
740745
to the fitted final :class:`.StatsModelsLinearRegression` object that
741-
represents the fitted CATE model.
746+
represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
747+
to the fitted featurizer and :attr:`bias_part_of_coef` that designates
748+
if the intercept is the first element of the :attr:`model_final_` coefficient.
742749
"""
743750

744751
def _get_inference_options(self):
@@ -771,7 +778,7 @@ def _get_inference_options(self):
771778

772779
@property
773780
def feature_importances_(self):
774-
return self.model_final.feature_importances_
781+
return self.model_final_.feature_importances_
775782

776783

777784
class LinearModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):
@@ -822,7 +829,7 @@ def intercept_(self, T):
822829
-------
823830
intercept: float or (n_y,) array like
824831
"""
825-
if not self.fit_cate_intercept:
832+
if not self.fit_cate_intercept_:
826833
raise AttributeError("No intercept was fitted!")
827834
_, T = self._expand_treatments(None, T)
828835
ind = inverse_onehot(T).item() - 1
@@ -980,7 +987,7 @@ class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscret
980987
Mixin class that offers `inference='statsmodels'` options to the CATE estimator
981988
that inherits it.
982989
983-
Such an estimator must implement a :attr:`model_final` attribute that points
990+
Such an estimator must implement a :attr:`model_final_` attribute that points
984991
to a :class:`.StatsModelsLinearRegression` object that is cloned to fit
985992
each discrete treatment target CATE model and a :attr:`fitted_models_final` attribute
986993
that returns the list of fitted final models that represent the CATE for each categorical treatment.

0 commit comments

Comments
 (0)