Commit 8d77dea

Merge pull request #116 from CITCOM-project/logistic-regression-confidence-intervals
Added confidence intervals to the logistic regression estimator
2 parents adecf3d + b84b76c commit 8d77dea

12 files changed: +459 -355 lines changed


causal_testing/json_front/json_class.py

Lines changed: 1 addition & 3 deletions
@@ -176,9 +176,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima

        result_string = str()
        if causal_test_result.ci_low() and causal_test_result.ci_high():
-            result_string = (
-                f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
-            )
+            result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
        else:
            result_string = causal_test_result.test_value.value
        if f_flag:

causal_testing/testing/estimators.py

Lines changed: 34 additions & 10 deletions
@@ -149,7 +149,8 @@ def _run_logistic_regression(self) -> RegressionResultsWrapper:
    def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
        """Estimate the outcomes under control and treatment.

-        :return: The average treatment effect and the 95% Wald confidence intervals.
+        :return: The estimated control and treatment values and their confidence
+                 intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
        """
        model = self._run_logistic_regression()
        self.model = model
@@ -168,31 +169,53 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
        x = x[model.params.index]

        y = model.predict(x)
-        return y.iloc[1], y.iloc[0]
+
+        # Delta method confidence intervals from
+        # https://stackoverflow.com/questions/47414842/confidence-interval-of-probability-prediction-from-logistic-regression-statsmode
+        cov = model.cov_params()
+        gradient = (y * (1 - y) * x.T).T  # matrix of gradients for each observation
+        std_errors = np.array([np.sqrt(np.dot(np.dot(g, cov), g)) for g in gradient.to_numpy()])
+        c = 1.96  # multiplier for confidence interval
+        upper = np.maximum(0, np.minimum(1, y + std_errors * c))
+        lower = np.maximum(0, np.minimum(1, y - std_errors * c))
+
+        return (lower.iloc[1], y.iloc[1], upper.iloc[1]), (lower.iloc[0], y.iloc[0], upper.iloc[0])

    def estimate_ate(self) -> float:
        """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value. Here, we actually
        calculate the expected outcomes under control and treatment and take one away from the other. This
        allows for custom terms to be put in such as squares, inverses, products, etc.

-        :return: The average treatment effect. Confidence intervals are not yet supported.
+        :return: The estimated average treatment effect and 95% confidence intervals
        """
-        control_outcome, treatment_outcome = self.estimate_control_treatment()
+        (cci_low, control_outcome, cci_high), (tci_low, treatment_outcome, tci_high) = self.estimate_control_treatment()
+
+        ci_low = tci_low - cci_high
+        ci_high = tci_high - cci_low
+        estimate = treatment_outcome - control_outcome

-        return treatment_outcome - control_outcome
+        logger.info(
+            f"Changing {self.treatment} from {self.control_values} to {self.treatment_values} gives an estimated ATE of {ci_low} < {estimate} < {ci_high}"
+        )
+        assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
+
+        return estimate, (ci_low, ci_high)

    def estimate_risk_ratio(self) -> float:
        """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value. Here, we actually
        calculate the expected outcomes under control and treatment and divide one by the other. This
        allows for custom terms to be put in such as squares, inverses, products, etc.

-        :return: The average treatment effect. Confidence intervals are not yet supported.
+        :return: The estimated risk ratio and 95% confidence intervals.
        """
-        control_outcome, treatment_outcome = self.estimate_control_treatment()
+        (cci_low, control_outcome, cci_high), (tci_low, treatment_outcome, tci_high) = self.estimate_control_treatment()

-        return treatment_outcome / control_outcome
+        ci_low = tci_low / cci_high
+        ci_high = tci_high / cci_low
+
+        return treatment_outcome / control_outcome, (ci_low, ci_high)

    def estimate_unit_odds_ratio(self) -> float:
        """Estimate the odds ratio of increasing the treatment by one. In logistic regression, this corresponds to the
@@ -214,7 +237,7 @@ def __init__(
        treatment: tuple,
        treatment_values: float,
        control_values: float,
-        adjustment_set: set,
+        adjustment_set: list[float],
        outcome: tuple,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable:Any] = None,
@@ -332,7 +355,8 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
    def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
        """Estimate the outcomes under control and treatment.

-        :return: The average treatment effect and the 95% Wald confidence intervals.
+        :return: The estimated outcome under control and treatment in the form
+                 (control_outcome, treatment_outcome).
        """
        model = self._run_linear_regression()
        self.model = model
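
The interval arithmetic that estimate_ate and estimate_risk_ratio above apply to the control and treatment triples can be restated on its own. The helpers below are a hypothetical sketch for illustration, not part of the estimator API.

def difference_interval(control, treatment):
    # Each argument is a (ci_low, point, ci_high) triple as returned by estimate_control_treatment.
    (c_low, c, c_high), (t_low, t, t_high) = control, treatment
    return t - c, (t_low - c_high, t_high - c_low)


def ratio_interval(control, treatment):
    # Assumes the control interval does not include zero.
    (c_low, c, c_high), (t_low, t, t_high) = control, treatment
    return t / c, (t_low / c_high, t_high / c_low)


# Made-up probabilities, purely for illustration:
print(difference_interval((0.42, 0.50, 0.58), (0.61, 0.70, 0.79)))  # about (0.2, (0.03, 0.37))
print(ratio_interval((0.42, 0.50, 0.58), (0.61, 0.70, 0.79)))       # about (1.4, (1.05, 1.88))

Subtracting or dividing by the opposite endpoints gives the widest interval consistent with the two marginal intervals; it is conservative because it ignores any correlation between the control and treatment predictions.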

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 5 additions & 4 deletions
@@ -9,15 +9,15 @@


class TestObservationalDataCollector(unittest.TestCase):
-
    def setUp(self) -> None:
        temp_dir_path = create_temp_dir_if_non_existent()
        self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
        self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
        # Y = 3*X1 + X2*X3 + 10
        self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
        self.observational_df["Y"] = self.observational_df.apply(
-            lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1)
+            lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1
+        )
        self.observational_df.to_csv(self.observational_df_path)
        self.X1 = Input("X1", int, uniform(1, 4))
        self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))

@@ -45,12 +45,13 @@ def test_data_constraints(self):

    def test_meta_population(self):
        def populate_m(data):
-            data['M'] = data['X1'] * 2
+            data["M"] = data["X1"] * 2
+
        meta = Meta("M", int, populate_m)
        scenario = Scenario({self.X1, meta})
        observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
        data = observational_data_collector.collect_data()
-        assert all((m == 2*x1 for x1, m in zip(data['X1'], data['M'])))
+        assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))

    def tearDown(self) -> None:
        remove_temp_dir_if_existent()

tests/generation_tests/test_abstract_test_case.py

Lines changed: 10 additions & 10 deletions
@@ -8,18 +8,21 @@
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
from causal_testing.testing.causal_test_outcome import Positive

+
class TestAbstractTestCase(unittest.TestCase):
    """
-        Class to test abstract test cases.
+    Class to test abstract test cases.
    """
+
    def setUp(self) -> None:
        temp_dir_path = create_temp_dir_if_non_existent()
        self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
        self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
        # Y = 3*X1 + X2*X3 + 10
        self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
-        self.observational_df["Y"] = self.observational_df.apply(lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10,
-                                                                 axis=1)
+        self.observational_df["Y"] = self.observational_df.apply(
+            lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1
+        )
        self.observational_df.to_csv(self.observational_df_path)
        self.X1 = Input("X1", float, uniform(1, 4))
        self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))

@@ -51,9 +54,9 @@ def test_str(self):
            expected_causal_effect={self.Y: Positive()},
            effect_modifiers=None,
        )
-        assert str(abstract) == \
-            "When we apply intervention {X1' > X1}, the effect on Output: Y::int should be Positive", \
-            f"Unexpected string {str(abstract)}"
+        assert (
+            str(abstract) == "When we apply intervention {X1' > X1}, the effect on Output: Y::int should be Positive"
+        ), f"Unexpected string {str(abstract)}"

    def test_datapath(self):
        scenario = Scenario({self.X1, self.X2, self.X3, self.X4})

@@ -109,7 +112,6 @@ def test_generate_concrete_test_cases_rct(self):
        assert len(concrete_tests) == 2, "Expected 2 concrete tests"
        assert len(runs) == 4, "Expected 4 runs"

-
    def test_infeasible_constraints(self):
        scenario = Scenario({self.X1, self.X2, self.X3, self.X4}, [self.X1.z3 > 2])
        scenario.setup_treatment_variables()

@@ -125,10 +127,9 @@ def test_infeasible_constraints(self):

        with self.assertWarns(Warning):
            concrete_tests, runs = abstract.generate_concrete_tests(4, rct=True, target_ks_score=0.1, hard_max=HARD_MAX)
-        self.assertTrue(all((x > 2 for x in runs['X1'])))
+        self.assertTrue(all((x > 2 for x in runs["X1"])))
        self.assertEqual(len(concrete_tests), HARD_MAX * NUM_STRATA)

-
    def test_feasible_constraints(self):
        scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
        scenario.setup_treatment_variables()

@@ -142,7 +143,6 @@ def test_feasible_constraints(self):
        concrete_tests, _ = abstract.generate_concrete_tests(4, rct=True, target_ks_score=0.1, hard_max=1000)
        assert len(concrete_tests) < 1000

-
    def tearDown(self) -> None:
        remove_temp_dir_if_existent()


tests/json_front_tests/test_json_class.py

Lines changed: 4 additions & 7 deletions
@@ -93,12 +93,10 @@ def test_generate_tests_from_json(self):
        self.json_class.test_plan = example_test
        effects = {"NoEffect": NoEffect()}
        mutates = {
-            "Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3 >
-                                  self.json_class.modelling_scenario.variables[x].z3
-        }
-        estimators = {
-            "LinearRegressionEstimator": LinearRegressionEstimator
+            "Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3
+            > self.json_class.modelling_scenario.variables[x].z3
        }
+        estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}

        with self.assertLogs() as captured:
            self.json_class.generate_tests(effects, mutates, estimators, False)

@@ -108,9 +106,8 @@ def test_generate_tests_from_json(self):

    def tearDown(self) -> None:
        pass
-        #remove_temp_dir_if_existent()
+        # remove_temp_dir_if_existent()


def populate_example(*args, **kwargs):
    pass
-
