diff --git a/dte_adj/stratified.py b/dte_adj/stratified.py index e7e3d30..880f371 100644 --- a/dte_adj/stratified.py +++ b/dte_adj/stratified.py @@ -73,10 +73,9 @@ def _compute_cumulative_distribution( for s in s_list: s_mask = strata == s w_s[s] = (s_mask & treatment_mask).sum() / s_mask.sum() - for i, outcome in enumerate(locations): - for j in range(n_records): - s = strata[j] - prediction[j, i] = (outcomes[j] <= outcome) / w_s[s] * treatment_mask[j] + for j in range(n_records): + s = strata[j] + prediction[j] = (outcomes[j] <= locations) / w_s[s] * treatment_mask[j] unconditional_pred = {s: prediction[s == strata].mean(axis=0) for s in s_list} conditional_prediction = np.array([unconditional_pred[s] for s in strata]) diff --git a/dte_adj/util.py b/dte_adj/util.py index c0265b3..6b52f13 100644 --- a/dte_adj/util.py +++ b/dte_adj/util.py @@ -1,6 +1,12 @@ import numpy as np from scipy.stats import norm -from typing import Tuple +from typing import Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from dte_adj.local import ( + SimpleStratifiedDistributionEstimator, + AdjustedLocalDistributionEstimator, + ) def compute_confidence_intervals( @@ -110,7 +116,7 @@ def compute_confidence_intervals( def _compute_local_treatment_effects_core( - estimator, + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", target_treatment_arm: int, control_treatment_arm: int, locations: np.ndarray, @@ -149,10 +155,10 @@ def _compute_local_treatment_effects_core( # Compute treatment propensity (probability of treatment) d_t_prediction, d_t_psi, d_t_eta = estimator._compute_cumulative_distribution( - target_treatment_arm, np.zeros(1), X, Z, 1 - D + target_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D) ) d_c_prediction, d_c_psi, d_c_eta = estimator._compute_cumulative_distribution( - control_treatment_arm, np.zeros(1), X, Z, 1 - D + control_treatment_arm, np.zeros(1), X, Z, 1 - (target_treatment_arm == D) ) # Compute outcome distributions (different for LDTE vs LPTE) @@ -257,7 +263,7 @@ def xi(s): def compute_ldte( - estimator, + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", target_treatment_arm: int, control_treatment_arm: int, locations: np.ndarray, @@ -290,7 +296,7 @@ def compute_ldte( def compute_lpte( - estimator, + estimator: "SimpleStratifiedDistributionEstimator | AdjustedLocalDistributionEstimator", target_treatment_arm: int, control_treatment_arm: int, locations: np.ndarray, diff --git a/example/example.ipynb b/example/example.ipynb index e81d582..a47b590 100644 --- a/example/example.ipynb +++ b/example/example.ipynb @@ -278,9 +278,7 @@ ], "source": [ "pte, lower_bound, upper_bound = estimator.predict_pte(\n", - " target_treatment_arm=1,\n", - " control_treatment_arm=0,\n", - " locations=locations\n", + " target_treatment_arm=1, control_treatment_arm=0, locations=locations\n", ")\n", "plot(\n", " locations[:-1],\n",