|
1 | 1 | import numpy as np |
2 | 2 | from scipy.stats import norm |
3 | | -from typing import Tuple |
| 3 | +from typing import Tuple, TYPE_CHECKING |
| 4 | + |
| 5 | +if TYPE_CHECKING: |
| 6 | + from dte_adj.local import ( |
| 7 | + SimpleStratifiedDistributionEstimator, |
| 8 | + AdjustedLocalDistributionEstimator, |
| 9 | + ) |
4 | 10 |
|
5 | 11 |
|
6 | 12 | def compute_confidence_intervals( |
@@ -110,7 +116,7 @@ def compute_confidence_intervals( |
110 | 116 |
|
111 | 117 |
|
112 | 118 | def _compute_local_treatment_effects_core( |
113 | | - estimator, |
| 119 | + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", |
114 | 120 | target_treatment_arm: int, |
115 | 121 | control_treatment_arm: int, |
116 | 122 | locations: np.ndarray, |
@@ -149,10 +155,10 @@ def _compute_local_treatment_effects_core( |
149 | 155 |
|
150 | 156 | # Compute treatment propensity (probability of treatment) |
151 | 157 | d_t_prediction, d_t_psi, d_t_eta = estimator._compute_cumulative_distribution( |
152 | | - target_treatment_arm, np.zeros(1), X, Z, 1 - D |
| 158 | + target_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D) |
153 | 159 | ) |
154 | 160 | d_c_prediction, d_c_psi, d_c_eta = estimator._compute_cumulative_distribution( |
155 | | - control_treatment_arm, np.zeros(1), X, Z, 1 - D |
| 161 | + control_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D) |
156 | 162 | ) |
157 | 163 |
|
158 | 164 | # Compute outcome distributions (different for LDTE vs LPTE) |
@@ -257,7 +263,7 @@ def xi(s): |
257 | 263 |
|
258 | 264 |
|
259 | 265 | def compute_ldte( |
260 | | - estimator, |
| 266 | + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", |
261 | 267 | target_treatment_arm: int, |
262 | 268 | control_treatment_arm: int, |
263 | 269 | locations: np.ndarray, |
@@ -290,7 +296,7 @@ def compute_ldte( |
290 | 296 |
|
291 | 297 |
|
292 | 298 | def compute_lpte( |
293 | | - estimator, |
| 299 | + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", |
294 | 300 | target_treatment_arm: int, |
295 | 301 | control_treatment_arm: int, |
296 | 302 | locations: np.ndarray, |
|
0 commit comments