Skip to content

Commit c17cbdb

Browse files
committed
Misimplemented e val. Re implemented
1 parent bfa7440 commit c17cbdb

File tree

2 files changed

+32
-39
lines changed

2 files changed

+32
-39
lines changed

causal_testing/testing/validation.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,36 @@ def estimate_robustness(self, model: RegressionResultsWrapper, q=1, alpha=1):
2424

2525
return rv
2626

27-
def estimate_e_value(
28-
self, risk_ratio, confidence_intervals: tuple[float, float]
29-
) -> tuple[float, tuple[float, float]]:
27+
def estimate_e_value(self, risk_ratio: float) -> float:
3028
"""Calculate the E value from a risk ratio. This allow
3129
the user to identify how large a risk an unidentified confounding
3230
variable would need to be to nullify the causal relationship
3331
under test."""
3432

3533
if risk_ratio >= 1:
36-
e = risk_ratio + math.sqrt(risk_ratio * (risk_ratio - 1))
34+
return risk_ratio + math.sqrt(risk_ratio * (risk_ratio - 1))
3735

38-
lower_limit = confidence_intervals[0]
39-
if lower_limit <= 1:
40-
lower_limit = 1
41-
else:
42-
lower_limit = lower_limit + math.sqrt(lower_limit * (lower_limit - 1))
36+
risk_ratio_prime = 1 / risk_ratio
37+
return risk_ratio_prime + math.sqrt(risk_ratio_prime * (risk_ratio_prime - 1))
4338

44-
return (e, (lower_limit, 1))
39+
def estimate_e_value_using_ci(self, risk_ratio: float, confidence_intervals: tuple[float, float]) -> float:
40+
"""Calculate the E value from a risk ratio and it's confidence intervals.
41+
This allow the user to identify how large a risk an unidentified
42+
confounding variable would need to be to nullify the causal relationship
43+
under test."""
4544

46-
risk_ratio_prime = 1 / risk_ratio
47-
e = risk_ratio_prime + math.sqrt(risk_ratio_prime * (risk_ratio_prime - 1))
45+
if risk_ratio >= 1:
46+
lower_limit = confidence_intervals[0]
47+
e = 1
48+
if lower_limit > 1:
49+
e = lower_limit + math.sqrt(lower_limit * (lower_limit - 1))
50+
51+
return e
4852

4953
upper_limit = confidence_intervals[1]
50-
if upper_limit >= 1:
51-
upper_limit = 1
52-
else:
54+
e = 1
55+
if upper_limit < 1:
5356
upper_limit_prime = 1 / upper_limit
54-
upper_limit = upper_limit_prime + math.sqrt(upper_limit_prime * (upper_limit_prime - 1))
57+
e = upper_limit_prime + math.sqrt(upper_limit_prime * (upper_limit_prime - 1))
5558

56-
return (e, (1, upper_limit))
59+
return e

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -179,31 +179,21 @@ def test_someEffect_fail(self):
179179
)
180180

181181
def test_positive_risk_ratio_e_value(self):
182-
test_value = TestValue("risk_ratio", 1.5)
183-
ctr = CausalTestResult(
184-
estimator=self.estimator,
185-
test_value=test_value,
186-
confidence_intervals=[1.2, 1.8],
187-
effect_modifier_configuration=None,
188-
)
189-
190182
cv = CausalValidator()
191-
e_value, e_confidence_intervals = cv.estimate_e_value(ctr.test_value.value, ctr.confidence_intervals)
183+
e_value = cv.estimate_e_value(1.5)
192184
self.assertEqual(round(e_value, 4), 2.366)
193-
self.assertEqual(round(e_confidence_intervals[0], 4), 1.6899)
194-
self.assertEqual(e_confidence_intervals[1], 1)
195185

196-
def test_negative_risk_ratio_e_value(self):
197-
test_value = TestValue("risk_ratio", 0.8)
198-
ctr = CausalTestResult(
199-
estimator=self.estimator,
200-
test_value=test_value,
201-
confidence_intervals=[0.2, 0.9],
202-
effect_modifier_configuration=None,
203-
)
186+
def test_positive_risk_ratio_e_value_using_ci(self):
187+
cv = CausalValidator()
188+
e_value = cv.estimate_e_value_using_ci(1.5, [1.2, 1.8])
189+
self.assertEqual(round(e_value, 4), 1.6899)
204190

191+
def test_negative_risk_ratio_e_value(self):
205192
cv = CausalValidator()
206-
e_value, e_confidence_intervals = cv.estimate_e_value(ctr.test_value.value, ctr.confidence_intervals)
193+
e_value = cv.estimate_e_value(0.8)
207194
self.assertEqual(round(e_value, 4), 1.809)
208-
self.assertEqual(e_confidence_intervals[0], 1)
209-
self.assertEqual(round(e_confidence_intervals[1], 4), 1.4625)
195+
196+
def test_negative_risk_ratio_e_value_using_ci(self):
197+
cv = CausalValidator()
198+
e_value = cv.estimate_e_value_using_ci(0.8, [0.2, 0.9])
199+
self.assertEqual(round(e_value, 4), 1.4625)

0 commit comments

Comments
 (0)