Skip to content

Commit a187fec

Browse files
authored
Merge pull request #169 from CITCOM-project/ci_pos_neg
Confidence intervals through inheritence
2 parents 0d53378 + b066d10 commit a187fec

File tree

3 files changed

+125
-45
lines changed

3 files changed

+125
-45
lines changed

causal_testing/testing/causal_test_outcome.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,29 @@ def __str__(self) -> str:
2222
return type(self).__name__
2323

2424

25-
class ExactValue(CausalTestOutcome):
25+
class SomeEffect(CausalTestOutcome):
26+
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
27+
28+
def apply(self, res: CausalTestResult) -> bool:
29+
if res.test_value.type == "ate":
30+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
31+
if res.test_value.type == "risk_ratio":
32+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
33+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
34+
35+
36+
class NoEffect(CausalTestOutcome):
37+
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
38+
39+
def apply(self, res: CausalTestResult) -> bool:
40+
if res.test_value.type == "ate":
41+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
42+
if res.test_value.type == "risk_ratio":
43+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
44+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
45+
46+
47+
class ExactValue(SomeEffect):
2648
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
2749

2850
def __init__(self, value: float, tolerance: float = None):
@@ -33,6 +55,8 @@ def __init__(self, value: float, tolerance: float = None):
3355
self.tolerance = tolerance
3456

3557
def apply(self, res: CausalTestResult) -> bool:
58+
if res.ci_valid():
59+
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.tolerance)
3660
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
3761

3862
def __str__(self):
@@ -43,6 +67,8 @@ class Positive(CausalTestOutcome):
4367
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
4468

4569
def apply(self, res: CausalTestResult) -> bool:
70+
if res.ci_valid() and not super().apply(res):
71+
return False
4672
if res.test_value.type == "ate":
4773
return res.test_value.value > 0
4874
if res.test_value.type == "risk_ratio":
@@ -54,36 +80,10 @@ class Negative(CausalTestOutcome):
5480
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
5581

5682
def apply(self, res: CausalTestResult) -> bool:
83+
if res.ci_valid() and not super().apply(res):
84+
return False
5785
if res.test_value.type == "ate":
5886
return res.test_value.value < 0
5987
if res.test_value.type == "risk_ratio":
6088
return res.test_value.value < 1
6189
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
62-
63-
64-
class SomeEffect(CausalTestOutcome):
65-
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
66-
67-
def apply(self, res: CausalTestResult) -> bool:
68-
if res.test_value.type == "ate":
69-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
70-
if res.test_value.type == "risk_ratio":
71-
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
72-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
73-
74-
def __str__(self):
75-
return "Changed"
76-
77-
78-
class NoEffect(CausalTestOutcome):
79-
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
80-
81-
def apply(self, res: CausalTestResult) -> bool:
82-
if res.test_value.type == "ate":
83-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
84-
if res.test_value.type == "risk_ratio":
85-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
86-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
87-
88-
def __str__(self):
89-
return "Unchanged"

causal_testing/testing/causal_test_result.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def ci_high(self):
8585
return max(self.confidence_intervals)
8686
return None
8787

88+
def ci_valid(self) -> bool:
89+
"""Return whether or not the result has valid confidence invervals"""
90+
return self.ci_low() and self.ci_high()
91+
8892
def summary(self):
8993
"""Summarise the causal test result as an intuitive sentence."""
9094
print(

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative
2+
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative, NoEffect
33
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
44
from causal_testing.testing.estimators import LinearRegressionEstimator
55
from causal_testing.testing.validation import CausalValidator
@@ -86,6 +86,17 @@ def test_Positive_fail(self):
8686
ev = Positive()
8787
self.assertFalse(ev.apply(ctr))
8888

89+
def test_Positive_fail_ci(self):
90+
test_value = TestValue(type="ate", value=0)
91+
ctr = CausalTestResult(
92+
estimator=self.estimator,
93+
test_value=test_value,
94+
confidence_intervals=[-1, 1],
95+
effect_modifier_configuration=None,
96+
)
97+
ev = Positive()
98+
self.assertFalse(ev.apply(ctr))
99+
89100
def test_Negative_pass(self):
90101
test_value = TestValue(type="ate", value=-5.05)
91102
ctr = CausalTestResult(
@@ -108,6 +119,17 @@ def test_Negative_fail(self):
108119
ev = Negative()
109120
self.assertFalse(ev.apply(ctr))
110121

122+
def test_Negative_fail_ci(self):
123+
test_value = TestValue(type="ate", value=0)
124+
ctr = CausalTestResult(
125+
estimator=self.estimator,
126+
test_value=test_value,
127+
confidence_intervals=[-1, 1],
128+
effect_modifier_configuration=None,
129+
)
130+
ev = Negative()
131+
self.assertFalse(ev.apply(ctr))
132+
111133
def test_exactValue_pass(self):
112134
test_value = TestValue(type="ate", value=5.05)
113135
ctr = CausalTestResult(
@@ -119,6 +141,17 @@ def test_exactValue_pass(self):
119141
ev = ExactValue(5, 0.1)
120142
self.assertTrue(ev.apply(ctr))
121143

144+
def test_exactValue_pass_ci(self):
145+
test_value = TestValue(type="ate", value=5.05)
146+
ctr = CausalTestResult(
147+
estimator=self.estimator,
148+
test_value=test_value,
149+
confidence_intervals=[4, 6],
150+
effect_modifier_configuration=None,
151+
)
152+
ev = ExactValue(5, 0.1)
153+
self.assertTrue(ev.apply(ctr))
154+
122155
def test_exactValue_fail(self):
123156
test_value = TestValue(type="ate", value=0)
124157
ctr = CausalTestResult(
@@ -130,18 +163,52 @@ def test_exactValue_fail(self):
130163
ev = ExactValue(5, 0.1)
131164
self.assertFalse(ev.apply(ctr))
132165

133-
def test_someEffect_pass(self):
134-
test_value = TestValue(type="ate", value=5.05)
166+
def test_someEffect_invalid(self):
167+
test_value = TestValue(type="invalid", value=5.05)
135168
ctr = CausalTestResult(
136169
estimator=self.estimator,
137170
test_value=test_value,
138171
confidence_intervals=[4.8, 6.7],
139172
effect_modifier_configuration=None,
140173
)
141174
ev = SomeEffect()
142-
self.assertTrue(ev.apply(ctr))
175+
with self.assertRaises(ValueError):
176+
ev.apply(ctr)
177+
178+
def test_someEffect_pass_ate(self):
179+
test_value = TestValue(type="ate", value=5.05)
180+
ctr = CausalTestResult(
181+
estimator=self.estimator,
182+
test_value=test_value,
183+
confidence_intervals=[4.8, 6.7],
184+
effect_modifier_configuration=None,
185+
)
186+
self.assertTrue(SomeEffect().apply(ctr))
187+
self.assertFalse(NoEffect().apply(ctr))
188+
189+
def test_someEffect_pass_rr(self):
190+
test_value = TestValue(type="risk_ratio", value=5.05)
191+
ctr = CausalTestResult(
192+
estimator=self.estimator,
193+
test_value=test_value,
194+
confidence_intervals=[4.8, 6.7],
195+
effect_modifier_configuration=None,
196+
)
197+
self.assertTrue(SomeEffect().apply(ctr))
198+
self.assertFalse(NoEffect().apply(ctr))
143199

144200
def test_someEffect_fail(self):
201+
test_value = TestValue(type="ate", value=0)
202+
ctr = CausalTestResult(
203+
estimator=self.estimator,
204+
test_value=test_value,
205+
confidence_intervals=[-0.1, 0.2],
206+
effect_modifier_configuration=None,
207+
)
208+
self.assertFalse(SomeEffect().apply(ctr))
209+
self.assertTrue(NoEffect().apply(ctr))
210+
211+
def test_someEffect_str(self):
145212
test_value = TestValue(type="ate", value=0)
146213
ctr = CausalTestResult(
147214
estimator=self.estimator,
@@ -150,20 +217,29 @@ def test_someEffect_fail(self):
150217
effect_modifier_configuration=None,
151218
)
152219
ev = SomeEffect()
153-
self.assertFalse(ev.apply(ctr))
154220
self.assertEqual(
155-
str(ctr),
156-
(
157-
"Causal Test Result\n==============\n"
158-
"Treatment: A\n"
159-
"Control value: 0\n"
160-
"Treatment value: 1\n"
161-
"Outcome: A\n"
162-
"Adjustment set: set()\n"
163-
"ate: 0\n"
164-
"Confidence intervals: [-0.1, 0.2]\n"
165-
),
221+
ctr.to_dict(),
222+
{
223+
"treatment": "A",
224+
"control_value": 0,
225+
"treatment_value": 1,
226+
"outcome": "A",
227+
"adjustment_set": set(),
228+
"test_value": test_value,
229+
"ci_low": -0.1,
230+
"ci_high": 0.2,
231+
},
232+
)
233+
234+
def test_someEffect_dict(self):
235+
test_value = TestValue(type="ate", value=0)
236+
ctr = CausalTestResult(
237+
estimator=self.estimator,
238+
test_value=test_value,
239+
confidence_intervals=[-0.1, 0.2],
240+
effect_modifier_configuration=None,
166241
)
242+
ev = SomeEffect()
167243
self.assertEqual(
168244
ctr.to_dict(),
169245
{

0 commit comments

Comments
 (0)