@@ -400,6 +400,14 @@ def __init__(self, *,
400400 categories = categories ,
401401 random_state = random_state )
402402
403+ def _get_inference_options (self ):
404+ options = super ()._get_inference_options ()
405+ if not self .multitask_model_final :
406+ options .update (auto = GenericModelFinalInferenceDiscrete )
407+ else :
408+ options .update (auto = lambda : None )
409+ return options
410+
403411 def _gen_ortho_learner_model_nuisance (self ):
404412 if self .model_propensity == 'auto' :
405413 model_propensity = LogisticRegressionCV (cv = 3 , solver = 'lbfgs' , multi_class = 'auto' ,
@@ -426,7 +434,7 @@ def _gen_ortho_learner_model_final(self):
426434 @_deprecate_positional ("X and W should be passed by keyword only. In a future release "
427435 "we will disallow passing X and W by position." , ['X' , 'W' ])
428436 def fit (self , Y , T , X = None , W = None , * , sample_weight = None , sample_var = None , groups = None ,
429- cache_values = False , inference = None ):
437+ cache_values = False , inference = 'auto' ):
430438 """
431439 Estimate the counterfactual model from data, i.e. estimates function :math:`\\ theta(\\ cdot)`.
432440
@@ -463,6 +471,10 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
463471 sample_weight = sample_weight , sample_var = sample_var , groups = groups ,
464472 cache_values = cache_values , inference = inference )
465473
474+ def refit_final (self , * , inference = 'auto' ):
475+ return super ().refit_final (inference = inference )
476+ refit_final .__doc__ = _OrthoLearner .refit_final .__doc__
477+
466478 def score (self , Y , T , X = None , W = None ):
467479 """
468480 Score the fitted CATE model on a new data set. Generates nuisance parameters
@@ -851,10 +863,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
851863 sample_weight = sample_weight , sample_var = sample_var , groups = groups ,
852864 cache_values = cache_values , inference = inference )
853865
854- def refit_final (self , * , inference = 'auto' ):
855- return super ().refit_final (inference = inference )
856- refit_final .__doc__ = _OrthoLearner .refit_final .__doc__
857-
858866 @property
859867 def fit_cate_intercept_ (self ):
860868 return self .model_final_ .fit_intercept
@@ -1151,10 +1159,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
11511159 sample_weight = sample_weight , sample_var = None , groups = groups ,
11521160 cache_values = cache_values , inference = inference )
11531161
1154- def refit_final (self , * , inference = 'auto' ):
1155- return super ().refit_final (inference = inference )
1156- refit_final .__doc__ = _OrthoLearner .refit_final .__doc__
1157-
11581162 @property
11591163 def fit_cate_intercept_ (self ):
11601164 return self .model_final_ .fit_intercept
@@ -1465,10 +1469,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
14651469 sample_weight = sample_weight , sample_var = None , groups = groups ,
14661470 cache_values = cache_values , inference = inference )
14671471
1468- def refit_final (self , * , inference = 'auto' ):
1469- return super ().refit_final (inference = inference )
1470- refit_final .__doc__ = _OrthoLearner .refit_final .__doc__
1471-
14721472 def multitask_model_cate (self ):
14731473 # Replacing to remove docstring
14741474 super ().multitask_model_cate ()
0 commit comments