Skip to content

Commit 305a2e3

Browse files
committed
Unify confidence interval (CI) logic with the DoubleML (dml) framework
1 parent 2459d92 commit 305a2e3

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

doubleml/utils/blp.py

Lines changed: 21 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
import statsmodels.api as sm
66
from scipy.linalg import sqrtm
7-
from scipy.stats import norm, t
7+
from scipy.stats import norm
88

99
from ._estimation import _aggregate_coefs_and_ses
1010

@@ -161,16 +161,16 @@ def summary(self):
161161
if self.blp_model is None:
162162
df_summary = pd.DataFrame(columns=col_names)
163163
else:
164-
critical_value = t.ppf(0.975, self._blp_model[0].df_resid) if self._blp_model[0].use_t else norm.ppf(0.975)
164+
conf_int_values = [self._blp_model[i].conf_int() for i in range(self.n_rep)]
165165
t_values = np.divide(self.coef, self.se)
166166
p_values = 2 * norm.cdf(-np.abs(t_values))
167167
summary_stats = {
168168
"coef": self.coef,
169169
"std err": self.se,
170170
"t": t_values,
171171
"P>|t|": p_values,
172-
"[0.025": self.coef - critical_value * self.se,
173-
"0.975]": self.coef + critical_value * self.se,
172+
"[0.025": np.median([conf_int_values[i][0] for i in range(self.n_rep)], axis=0),
173+
"0.975]": np.median([conf_int_values[i][1] for i in range(self.n_rep)], axis=0),
174174
}
175175
df_summary = pd.DataFrame(summary_stats, columns=col_names, index=self._basis.columns)
176176
return df_summary
@@ -271,11 +271,9 @@ def confint(self, basis=None, joint=False, level=0.95, n_rep_boot=500):
271271
if joint:
272272
warnings.warn("Returning pointwise confidence intervals for basis coefficients.", UserWarning)
273273
# return the confidence intervals for the basis coefficients
274-
critical_value = (
275-
t.ppf(1 - alpha / 2, self._blp_model[0].df_resid) if self._blp_model[0].use_t else norm.ppf(1 - alpha / 2)
276-
)
277-
ci_lower = self.coef - critical_value * self.se
278-
ci_upper = self.coef + critical_value * self.se
274+
conf_int_values = [self._blp_model[i].conf_int(alpha=alpha) for i in range(self.n_rep)]
275+
ci_lower = np.median([conf_int_values[i][0] for i in range(self.n_rep)], axis=0)
276+
ci_upper = np.median([conf_int_values[i][1] for i in range(self.n_rep)], axis=0)
279277
ci = np.vstack((ci_lower, self.coef, ci_upper)).T
280278
df_ci = pd.DataFrame(
281279
ci,
@@ -292,31 +290,27 @@ def confint(self, basis=None, joint=False, level=0.95, n_rep_boot=500):
292290
raise ValueError("Invalid basis: DataFrame has to have the exact same number and ordering of columns.")
293291

294292
# blp of the orthogonal signal
295-
g_hat, blp_se, _, _ = self._predict_and_aggregate(basis)
293+
g_hat, _, all_g_hat, all_blp_se = self._predict_and_aggregate(basis)
296294

297295
if joint:
298296
np_basis = basis.to_numpy()
299-
bootstrap_samples = np.full((basis.shape[0], self.n_rep, n_rep_boot), np.nan)
297+
critical_values = np.full(self.n_rep, np.nan)
298+
300299
for i_rep in range(self.n_rep):
301300
normal_samples = np.random.normal(size=[basis.shape[1], n_rep_boot])
302-
omega_sqrt = np.real(sqrtm(self._blp_omega[:, :, i_rep]))
303-
bootstrap_samples[:, i_rep, :] = np.dot(np_basis, np.dot(omega_sqrt, normal_samples))
304-
305-
# aggregate the draws over repetitions according to the median aggregation rule
306-
bootstrap_samples = np.divide(np.median(bootstrap_samples, axis=1), blp_se.reshape(-1, 1))
307-
308-
max_t_stat = np.quantile(np.max(np.abs(bootstrap_samples), axis=1), q=level)
301+
omega_sqrt = sqrtm(self._blp_omega[:, :, i_rep])
302+
bootstrap_samples = np.multiply(
303+
np.dot(np_basis, np.dot(omega_sqrt, normal_samples)).T, (1.0 / all_blp_se[:, i_rep])
304+
)
305+
critical_values[i_rep] = np.quantile(np.max(np.abs(bootstrap_samples), axis=0), q=level)
306+
else:
307+
critical_values = np.repeat(norm.ppf(q=1 - alpha / 2), self.n_rep)
309308

310-
# Lower simultaneous CI
311-
g_hat_lower = g_hat - max_t_stat * blp_se
312-
# Upper simultaneous CI
313-
g_hat_upper = g_hat + max_t_stat * blp_se
309+
all_g_hat_lower = all_g_hat - critical_values * all_blp_se
310+
all_g_hat_upper = all_g_hat + critical_values * all_blp_se
314311

315-
else:
316-
# Lower point-wise CI
317-
g_hat_lower = g_hat + norm.ppf(q=alpha / 2) * blp_se
318-
# Upper point-wise CI
319-
g_hat_upper = g_hat + norm.ppf(q=1 - alpha / 2) * blp_se
312+
g_hat_lower = np.median(all_g_hat_lower, axis=1)
313+
g_hat_upper = np.median(all_g_hat_upper, axis=1)
320314

321315
ci = np.vstack((g_hat_lower, g_hat, g_hat_upper)).T
322316
df_ci = pd.DataFrame(

0 commit comments

Comments
 (0)