Skip to content

Commit 0d53378

Browse files
authored
Merge pull request #113 from CITCOM-project/rsomers/robustness
Robustness Estimators
2 parents 6ee781d + c17cbdb commit 0d53378

File tree

4 files changed

+91
-0
lines changed

4 files changed

+91
-0
lines changed

causal_testing/testing/estimators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
355355
:return: The average treatment effect and the 95% Wald confidence intervals.
356356
"""
357357
model = self._run_linear_regression()
358+
self.model = model
358359

359360
# Create an empty individual for the control and treated
360361
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)

causal_testing/testing/validation.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""This module contains the CausalValidator class for performing Quantitative Bias Analysis techniques"""
2+
import math
3+
import numpy as np
4+
from scipy.stats import t
5+
from statsmodels.regression.linear_model import RegressionResultsWrapper
6+
7+
8+
class CausalValidator:
    """A suite of validation tools to perform Quantitative Bias Analysis to back up causal claims."""

    def estimate_robustness(self, model: "RegressionResultsWrapper", q=1, alpha=1):
        """Calculate the robustness of a linear regression model.

        This allows the user to identify how large an unidentified confounding
        variable would need to be to nullify the causal relationship under test.

        :param model: Fitted regression results. Only ``model.df_resid`` and
            ``model.tvalues`` are read.
        :param q: Multiplier applied to the scaled absolute t-values.
        :param alpha: Significance level used to compute the critical threshold
            that is subtracted from the scaled t-values. With the default of 1
            the threshold is zero.
        :return: The robustness value(s), in the same container type as
            ``model.tvalues`` (e.g. indexable by coefficient name). A value of 0
            is returned wherever the scaled t-value does not exceed the critical
            threshold.
        """
        dof = model.df_resid
        t_values = model.tvalues

        fq = q * abs(t_values / math.sqrt(dof))
        f_crit = abs(t.ppf(alpha / 2, dof - 1)) / math.sqrt(dof - 1)
        # Clamp at zero: fqa is only used squared below, so a negative value
        # (estimate already below the critical threshold) would otherwise be
        # reported as a spurious positive robustness value.
        fqa = np.maximum(fq - f_crit, 0)

        return 0.5 * (np.sqrt(fqa**4 + (4 * fqa**2)) - fqa**2)

    @staticmethod
    def _e_value_formula(ratio: float) -> float:
        """Return ``ratio + sqrt(ratio * (ratio - 1))`` for a ratio of at least 1."""
        return ratio + math.sqrt(ratio * (ratio - 1))

    def estimate_e_value(self, risk_ratio: float) -> float:
        """Calculate the E value from a risk ratio.

        This allows the user to identify how large a risk an unidentified
        confounding variable would need to be to nullify the causal
        relationship under test.

        :param risk_ratio: The risk ratio. Values below 1 are inverted before
            the formula is applied.
        :return: The E value for the given risk ratio.
        """
        if risk_ratio < 1:
            # The formula is defined for ratios >= 1, so use the reciprocal.
            risk_ratio = 1 / risk_ratio
        return self._e_value_formula(risk_ratio)

    def estimate_e_value_using_ci(self, risk_ratio: float, confidence_intervals: tuple[float, float]) -> float:
        """Calculate the E value from a risk ratio and its confidence intervals.

        This allows the user to identify how large a risk an unidentified
        confounding variable would need to be to nullify the causal
        relationship under test.

        :param risk_ratio: The risk ratio; it selects which confidence limit is used.
        :param confidence_intervals: The ``(lower, upper)`` confidence limits of
            the risk ratio.
        :return: The E value of the confidence limit closest to 1, or 1 when the
            interval reaches or crosses 1 on that side.
        """
        if risk_ratio >= 1:
            lower_limit = confidence_intervals[0]
            return self._e_value_formula(lower_limit) if lower_limit > 1 else 1

        upper_limit = confidence_intervals[1]
        return self._e_value_formula(1 / upper_limit) if upper_limit < 1 else 1

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative
33
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
44
from causal_testing.testing.estimators import LinearRegressionEstimator
5+
from causal_testing.testing.validation import CausalValidator
56

67

78
class TestCausalTestOutcome(unittest.TestCase):
@@ -176,3 +177,23 @@ def test_someEffect_fail(self):
176177
"ci_high": 0.2,
177178
},
178179
)
180+
181+
def test_positive_risk_ratio_e_value(self):
    """The E value of a risk ratio above 1 (1.5) rounds to 2.366."""
    validator = CausalValidator()
    self.assertEqual(round(validator.estimate_e_value(1.5), 4), 2.366)
185+
186+
def test_positive_risk_ratio_e_value_using_ci(self):
    """With a risk ratio above 1, the E value is taken from the lower confidence limit (1.2)."""
    validator = CausalValidator()
    self.assertEqual(round(validator.estimate_e_value_using_ci(1.5, [1.2, 1.8]), 4), 1.6899)
190+
191+
def test_negative_risk_ratio_e_value(self):
    """The E value of a risk ratio below 1 (0.8) rounds to 1.809."""
    validator = CausalValidator()
    self.assertEqual(round(validator.estimate_e_value(0.8), 4), 1.809)
195+
196+
def test_negative_risk_ratio_e_value_using_ci(self):
    """With a risk ratio below 1, the E value is taken from the upper confidence limit (0.9)."""
    validator = CausalValidator()
    self.assertEqual(round(validator.estimate_e_value_using_ci(0.8, [0.2, 0.9]), 4), 1.4625)

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
InstrumentalVariableEstimator,
1010
)
1111
from causal_testing.specification.variable import Input
12+
from causal_testing.testing.validation import CausalValidator
1213

1314

1415
def plot_results_df(df):
@@ -372,6 +373,15 @@ def test_program_15_no_interaction_ate_calculated(self):
372373
self.assertEqual(round(ate, 1), 3.5)
373374
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
374375

376+
def test_program_11_2_with_robustness_validation(self):
    """Check the robustness value estimated for the linear regression model used in test_program_11_2."""
    data = self.chapter_11_df.copy()
    estimator = LinearRegressionEstimator("treatments", 100, 90, set(), "outcomes", data)
    fitted_model = estimator._run_linear_regression()

    robustness = CausalValidator().estimate_robustness(fitted_model)
    self.assertEqual(round(robustness["treatments"], 4), 0.7353)
384+
375385

376386
class TestCausalForestEstimator(unittest.TestCase):
377387
"""Test the linear regression estimator against the programming exercises in Section 2 of Hernán and Robins [1].

0 commit comments

Comments
 (0)