Skip to content

Commit d283bf7

Browse files
authored
Support target treatment indicator other than 1 (#78)
* support target treatment indicator other than 1 * revert
1 parent 526a383 commit d283bf7

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

dte_adj/stratified.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ def _compute_cumulative_distribution(
7373
for s in s_list:
7474
s_mask = strata == s
7575
w_s[s] = (s_mask & treatment_mask).sum() / s_mask.sum()
76-
for i, outcome in enumerate(locations):
77-
for j in range(n_records):
78-
s = strata[j]
79-
prediction[j, i] = (outcomes[j] <= outcome) / w_s[s] * treatment_mask[j]
76+
for j in range(n_records):
77+
s = strata[j]
78+
prediction[j] = (outcomes[j] <= locations) / w_s[s] * treatment_mask[j]
8079

8180
unconditional_pred = {s: prediction[s == strata].mean(axis=0) for s in s_list}
8281
conditional_prediction = np.array([unconditional_pred[s] for s in strata])

dte_adj/util.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import numpy as np
22
from scipy.stats import norm
3-
from typing import Tuple
3+
from typing import Tuple, TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from dte_adj.local import (
7+
SimpleStratifiedDistributionEstimator,
8+
AdjustedLocalDistributionEstimator,
9+
)
410

511

612
def compute_confidence_intervals(
@@ -110,7 +116,7 @@ def compute_confidence_intervals(
110116

111117

112118
def _compute_local_treatment_effects_core(
113-
estimator,
119+
estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator",
114120
target_treatment_arm: int,
115121
control_treatment_arm: int,
116122
locations: np.ndarray,
@@ -149,10 +155,10 @@ def _compute_local_treatment_effects_core(
149155

150156
# Compute treatment propensity (probability of treatment)
151157
d_t_prediction, d_t_psi, d_t_eta = estimator._compute_cumulative_distribution(
152-
target_treatment_arm, np.zeros(1), X, Z, 1 - D
158+
target_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D)
153159
)
154160
d_c_prediction, d_c_psi, d_c_eta = estimator._compute_cumulative_distribution(
155-
control_treatment_arm, np.zeros(1), X, Z, 1 - D
161+
control_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D)
156162
)
157163

158164
# Compute outcome distributions (different for LDTE vs LPTE)
@@ -257,7 +263,7 @@ def xi(s):
257263

258264

259265
def compute_ldte(
260-
estimator,
266+
estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator",
261267
target_treatment_arm: int,
262268
control_treatment_arm: int,
263269
locations: np.ndarray,
@@ -290,7 +296,7 @@ def compute_ldte(
290296

291297

292298
def compute_lpte(
293-
estimator,
299+
estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator",
294300
target_treatment_arm: int,
295301
control_treatment_arm: int,
296302
locations: np.ndarray,

example/example.ipynb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,7 @@
278278
],
279279
"source": [
280280
"pte, lower_bound, upper_bound = estimator.predict_pte(\n",
281-
" target_treatment_arm=1,\n",
282-
" control_treatment_arm=0,\n",
283-
" locations=locations\n",
281+
" target_treatment_arm=1, control_treatment_arm=0, locations=locations\n",
284282
")\n",
285283
"plot(\n",
286284
" locations[:-1],\n",

0 commit comments

Comments
 (0)