Skip to content

Commit ec4de83

Browse files
committed
Blacked
1 parent 38d18b7 commit ec4de83

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

causal_testing/testing/estimators.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
173173
# Delta method confidence intervals from
174174
# https://stackoverflow.com/questions/47414842/confidence-interval-of-probability-prediction-from-logistic-regression-statsmode
175175
cov = model.cov_params()
176-
gradient = (y * (1 - y) * x.T).T # matrix of gradients for each observation
176+
gradient = (y * (1 - y) * x.T).T # matrix of gradients for each observation
177177
std_errors = np.array([np.sqrt(np.dot(np.dot(g, cov), g)) for g in gradient.to_numpy()])
178-
c = 1.96 # multiplier for confidence interval
178+
c = 1.96 # multiplier for confidence interval
179179
upper = np.maximum(0, np.minimum(1, y + std_errors * c))
180180
lower = np.maximum(0, np.minimum(1, y - std_errors * c))
181181

@@ -195,7 +195,9 @@ def estimate_ate(self) -> float:
195195
ci_high = tci_high - cci_low
196196
estimate = treatment_outcome - control_outcome
197197

198-
logger.info(f"Changing {self.treatment} from {self.control_values} to {self.treatment_values} gives an estimated ATE of {ci_low} < {estimate} < {ci_high}")
198+
logger.info(
199+
f"Changing {self.treatment} from {self.control_values} to {self.treatment_values} gives an estimated ATE of {ci_low} < {estimate} < {ci_high}"
200+
)
199201
assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
200202

201203
return estimate, (ci_low, ci_high)

0 commit comments

Comments
 (0)