Skip to content

Commit f0b0e5b

Browse files
authored
Enabling summary() even when inference not available (#363)
* enable summary inference with stderr = None
1 parent 1dd73c7 commit f0b0e5b

File tree

4 files changed

+309
-102
lines changed

4 files changed

+309
-102
lines changed

econml/drlearner.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,14 @@ def __init__(self, *,
400400
categories=categories,
401401
random_state=random_state)
402402

403+
def _get_inference_options(self):
404+
options = super()._get_inference_options()
405+
if not self.multitask_model_final:
406+
options.update(auto=GenericModelFinalInferenceDiscrete)
407+
else:
408+
options.update(auto=lambda: None)
409+
return options
410+
403411
def _gen_ortho_learner_model_nuisance(self):
404412
if self.model_propensity == 'auto':
405413
model_propensity = LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto',
@@ -426,7 +434,7 @@ def _gen_ortho_learner_model_final(self):
426434
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
427435
"we will disallow passing X and W by position.", ['X', 'W'])
428436
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
429-
cache_values=False, inference=None):
437+
cache_values=False, inference='auto'):
430438
"""
431439
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
432440
@@ -463,6 +471,10 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
463471
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
464472
cache_values=cache_values, inference=inference)
465473

474+
def refit_final(self, *, inference='auto'):
475+
return super().refit_final(inference=inference)
476+
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__
477+
466478
def score(self, Y, T, X=None, W=None):
467479
"""
468480
Score the fitted CATE model on a new data set. Generates nuisance parameters
@@ -851,10 +863,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
851863
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
852864
cache_values=cache_values, inference=inference)
853865

854-
def refit_final(self, *, inference='auto'):
855-
return super().refit_final(inference=inference)
856-
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__
857-
858866
@property
859867
def fit_cate_intercept_(self):
860868
return self.model_final_.fit_intercept
@@ -1151,10 +1159,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
11511159
sample_weight=sample_weight, sample_var=None, groups=groups,
11521160
cache_values=cache_values, inference=inference)
11531161

1154-
def refit_final(self, *, inference='auto'):
1155-
return super().refit_final(inference=inference)
1156-
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__
1157-
11581162
@property
11591163
def fit_cate_intercept_(self):
11601164
return self.model_final_.fit_intercept
@@ -1465,10 +1469,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
14651469
sample_weight=sample_weight, sample_var=None, groups=groups,
14661470
cache_values=cache_values, inference=inference)
14671471

1468-
def refit_final(self, *, inference='auto'):
1469-
return super().refit_final(inference=inference)
1470-
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__
1471-
14721472
def multitask_model_cate(self):
14731473
# Replacing to remove docstring
14741474
super().multitask_model_cate()

0 commit comments

Comments
 (0)