@@ -529,8 +529,11 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
529529 """
530530 Base class for models where the final stage is a linear model.
531531
532- Subclasses must expose a ``model_final`` attribute containing the model's
533- final stage model.
532+ Such an estimator must implement a :attr:`model_final_` attribute that points
533+ to the fitted final :class:`.StatsModelsLinearRegression` object that
534+ represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
535+ to the fitted featurizer and :attr:`bias_part_of_coef` that designates
536+ if the intercept is the first element of the :attr:`model_final_` coefficient.
534537
535538 Attributes
536539 ----------
@@ -544,7 +547,9 @@ def _get_inference_options(self):
544547 options .update (auto = LinearModelFinalInference )
545548 return options
546549
547- bias_part_of_coef = False
550+ @property
551+ def bias_part_of_coef (self ):
552+ return False
548553
549554 @property
550555 def coef_ (self ):
@@ -561,9 +566,9 @@ def coef_(self):
561566 a vector and not a 2D array. For binary treatment the n_t dimension is
562567 also omitted.
563568 """
564- return parse_final_model_params (self .model_final .coef_ , self .model_final .intercept_ ,
569+ return parse_final_model_params (self .model_final_ .coef_ , self .model_final_ .intercept_ ,
565570 self ._d_y , self ._d_t , self ._d_t_in , self .bias_part_of_coef ,
566- self .fit_cate_intercept )[0 ]
571+ self .fit_cate_intercept_ )[0 ]
567572
568573 @property
569574 def intercept_ (self ):
@@ -578,11 +583,11 @@ def intercept_(self):
578583 a vector and not a 2D array. For binary treatment the n_t dimension is
579584 also omitted.
580585 """
581- if not self .fit_cate_intercept :
586+ if not self .fit_cate_intercept_ :
582587 raise AttributeError ("No intercept was fitted!" )
583- return parse_final_model_params (self .model_final .coef_ , self .model_final .intercept_ ,
588+ return parse_final_model_params (self .model_final_ .coef_ , self .model_final_ .intercept_ ,
584589 self ._d_y , self ._d_t , self ._d_t_in , self .bias_part_of_coef ,
585- self .fit_cate_intercept )[1 ]
590+ self .fit_cate_intercept_ )[1 ]
586591
587592 @BaseCateEstimator ._defer_to_inference
588593 def coef__interval (self , * , alpha = 0.1 ):
@@ -718,11 +723,11 @@ def summary(self, alpha=0.1, value=0, decimals=3, feature_names=None, treatment_
718723 return smry
719724
720725 def shap_values (self , X , * , feature_names = None , treatment_names = None , output_names = None , background_samples = 100 ):
721- if hasattr (self , "featurizer " ) and self .featurizer is not None :
722- X = self .featurizer .transform (X )
726+ if hasattr (self , "featurizer_ " ) and self .featurizer_ is not None :
727+ X = self .featurizer_ .transform (X )
723728 feature_names = self .cate_feature_names (feature_names )
724- return _shap_explain_joint_linear_model_cate (self .model_final , X , self ._d_t , self ._d_y ,
725- self .fit_cate_intercept ,
729+ return _shap_explain_joint_linear_model_cate (self .model_final_ , X , self ._d_t , self ._d_y ,
730+ self .bias_part_of_coef ,
726731 feature_names = feature_names , treatment_names = treatment_names ,
727732 output_names = output_names ,
728733 input_names = self ._input_names ,
@@ -736,9 +741,11 @@ class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
736741 Mixin class that offers `inference='statsmodels'` options to the CATE estimator
737742 that inherits it.
738743
739- Such an estimator must implement a :attr:`model_final ` attribute that points
744+ Such an estimator must implement a :attr:`model_final_ ` attribute that points
740745 to the fitted final :class:`.StatsModelsLinearRegression` object that
741- represents the fitted CATE model.
746+ represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
747+ to the fitted featurizer and :attr:`bias_part_of_coef` that designates
748+ if the intercept is the first element of the :attr:`model_final_` coefficient.
742749 """
743750
744751 def _get_inference_options (self ):
@@ -771,7 +778,7 @@ def _get_inference_options(self):
771778
772779 @property
773780 def feature_importances_ (self ):
774- return self .model_final .feature_importances_
781+ return self .model_final_ .feature_importances_
775782
776783
777784class LinearModelFinalCateEstimatorDiscreteMixin (BaseCateEstimator ):
@@ -822,7 +829,7 @@ def intercept_(self, T):
822829 -------
823830 intercept: float or (n_y,) array like
824831 """
825- if not self .fit_cate_intercept :
832+ if not self .fit_cate_intercept_ :
826833 raise AttributeError ("No intercept was fitted!" )
827834 _ , T = self ._expand_treatments (None , T )
828835 ind = inverse_onehot (T ).item () - 1
@@ -980,7 +987,7 @@ class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscret
980987 Mixin class that offers `inference='statsmodels'` options to the CATE estimator
981988 that inherits it.
982989
983- Such an estimator must implement a :attr:`model_final ` attribute that points
990+ Such an estimator must implement a :attr:`model_final_ ` attribute that points
984991 to a :class:`.StatsModelsLinearRegression` object that is cloned to fit
985992 each discrete treatment target CATE model and a :attr:`fitted_models_final` attribute
986993 that returns the list of fitted final models that represent the CATE for each categorical treatment.
0 commit comments