Skip to content

Commit 7ef2a5c

Browse files
authored
Merge pull request #25 from CyberAgentAILab/feat/multiplier-var
Add multiplier confidence interval
2 parents 04a97db + 8ec79a6 commit 7ef2a5c

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

dte_adj/util.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def compute_confidence_intervals(
7070
omega / num_obs
7171
)
7272
return vec_dte_lower_moment, vec_dte_upper_moment
73-
elif variance_type == "uniform":
73+
elif variance_type in ["uniform", "multiplier"]:
7474
tstats = np.zeros((n_bootstrap, len(vec_loc)))
7575
boot_draw = np.zeros((n_bootstrap, len(vec_loc)))
7676

@@ -83,13 +83,25 @@ def compute_confidence_intervals(
8383
1 / num_obs * np.sum(xi[:, np.newaxis] * influence_function, axis=0)
8484
)
8585

86-
tstats = np.abs(boot_draw)[:, :-1] / np.sqrt(omega[:-1] / num_obs)
87-
max_tstats = np.max(tstats, axis=1)
88-
quantile_max_tstats = np.quantile(max_tstats, 1 - alpha)
86+
if variance_type == "uniform":
87+
tstats = np.abs(boot_draw)[:, :-1] / np.sqrt(omega[:-1] / num_obs)
88+
max_tstats = np.max(tstats, axis=1)
89+
quantile_max_tstats = np.quantile(max_tstats, 1 - alpha)
8990

90-
vec_dte_lower_boot = vec_dte - quantile_max_tstats * np.sqrt(omega / num_obs)
91-
vec_dte_upper_boot = vec_dte + quantile_max_tstats * np.sqrt(omega / num_obs)
92-
return vec_dte_lower_boot, vec_dte_upper_boot
91+
se = (
92+
np.quantile(boot_draw, 0.75, axis=0)
93+
- np.quantile(boot_draw, 0.25, axis=0)
94+
) / (norm.ppf(0.75) - norm.ppf(0.25))
95+
96+
vec_dte_lower_boot = vec_dte - quantile_max_tstats * se
97+
vec_dte_upper_boot = vec_dte + quantile_max_tstats * se
98+
return vec_dte_lower_boot, vec_dte_upper_boot
99+
else:
100+
se = np.std(boot_draw, axis=0)
101+
102+
vec_dte_lower_boot = vec_dte + se * norm.ppf(alpha / 2)
103+
vec_dte_upper_boot = vec_dte + se * norm.ppf(1 - alpha / 2)
104+
return vec_dte_lower_boot, vec_dte_upper_boot
93105
elif variance_type == "simple":
94106
w_target = num_obs / num_target
95107
w_control = num_obs / num_control

0 commit comments

Comments
 (0)