Skip to content

Commit 2c140f2

Browse files
authored
Correct default value of alphas (#476)
The PopulationSummaryResults class should to default to self.alpha when the alpha parameter is not passed. This is not the case currently since some methods use the statement 'alpha = self.alpha if alpha is None else alpha', however alpha's default value in the methods is .1, thus self.alpha will be ignored even if the client does not pass the alpha parameter. Without the fix ate_interval(), marginal_ate_interval() and const_marginal_ate_interval() all ignore the alpha parameter.
1 parent 1bb4f3f commit 2c140f2

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

econml/inference/_inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,8 +1179,8 @@ class PopulationSummaryResults:
11791179
11801180
"""
11811181

1182-
def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol,
1183-
output_names=None, treatment_names=None):
1182+
def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha=0.1,
1183+
value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
11841184
self.pred = pred
11851185
self.pred_stderr = pred_stderr
11861186
self.mean_pred_stderr = mean_pred_stderr
@@ -1237,13 +1237,13 @@ def stderr_mean(self):
12371237
raise AttributeError("Only point estimates are available!")
12381238
return np.sqrt(np.mean(self.pred_stderr**2, axis=0))
12391239

1240-
def zstat(self, *, value=0):
1240+
def zstat(self, *, value=None):
12411241
"""
12421242
Get the z statistic of the mean point estimate of each treatment on each outcome for sample X.
12431243
12441244
Parameters
12451245
----------
1246-
value: optinal float (default=0)
1246+
value: optional float (default=0)
12471247
The mean value of the metric you'd like to test under null hypothesis.
12481248
12491249
Returns
@@ -1258,13 +1258,13 @@ def zstat(self, *, value=0):
12581258
zstat = (self.mean_point - value) / self.stderr_mean
12591259
return zstat
12601260

1261-
def pvalue(self, *, value=0):
1261+
def pvalue(self, *, value=None):
12621262
"""
12631263
Get the p value of the z test of each treatment on each outcome for sample X.
12641264
12651265
Parameters
12661266
----------
1267-
value: optinal float (default=0)
1267+
value: optional float (default=0)
12681268
The mean value of the metric you'd like to test under null hypothesis.
12691269
12701270
Returns
@@ -1275,10 +1275,11 @@ def pvalue(self, *, value=0):
12751275
the corresponding singleton dimensions in the output will be collapsed
12761276
(e.g. if both are vectors, then the output of this method will be a scalar)
12771277
"""
1278+
value = self.value if value is None else value
12781279
pvalue = norm.sf(np.abs(self.zstat(value=value)), loc=0, scale=1) * 2
12791280
return pvalue
12801281

1281-
def conf_int_mean(self, *, alpha=.1):
1282+
def conf_int_mean(self, *, alpha=None):
12821283
"""
12831284
Get the confidence interval of the mean point estimate of each treatment on each outcome for sample X.
12841285
@@ -1323,7 +1324,7 @@ def std_point(self):
13231324
"""
13241325
return np.std(self.pred, axis=0)
13251326

1326-
def percentile_point(self, *, alpha=.1):
1327+
def percentile_point(self, *, alpha=None):
13271328
"""
13281329
Get the confidence interval of the point estimate of each treatment on each outcome for sample X.
13291330
@@ -1346,7 +1347,7 @@ def percentile_point(self, *, alpha=.1):
13461347
upper_percentile_point = np.percentile(self.pred, (1 - alpha / 2) * 100, axis=0)
13471348
return lower_percentile_point, upper_percentile_point
13481349

1349-
def conf_int_point(self, *, alpha=.1, tol=.001):
1350+
def conf_int_point(self, *, alpha=None, tol=None):
13501351
"""
13511352
Get the confidence interval of the point estimate of each treatment on each outcome for sample X.
13521353
@@ -1389,7 +1390,7 @@ def stderr_point(self):
13891390
"""
13901391
return np.sqrt(self.stderr_mean**2 + self.std_point**2)
13911392

1392-
def summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None):
1393+
def summary(self, alpha=None, value=None, decimals=None, tol=None, output_names=None, treatment_names=None):
13931394
"""
13941395
Output the summary inferences above.
13951396

econml/tests/test_inference.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,18 @@ def test_can_summarize(self):
288288
inference=BootstrapInference(5)
289289
).summary(1)
290290

291+
def test_alpha(self):
292+
Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W
293+
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression())
294+
est.fit(Y, T, X=X, W=W)
295+
296+
# ensure alpha is passed
297+
lb, ub = est.const_marginal_ate_interval(X, alpha=1)
298+
assert (lb == ub).all()
299+
300+
lb, ub = est.const_marginal_ate_interval(X)
301+
assert (lb != ub).all()
302+
291303
def test_inference_with_none_stderr(self):
292304
Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W
293305
est = DML(model_y=LinearRegression(),

0 commit comments

Comments
 (0)