Skip to content

Commit 7c40163

Browse files
committed
Confidence intervals through inheritence
1 parent 6764fb3 commit 7c40163

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

causal_testing/testing/causal_test_outcome.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,35 @@ def __str__(self) -> str:
2222
return type(self).__name__
2323

2424

25-
class ExactValue(CausalTestOutcome):
25+
class SomeEffect(CausalTestOutcome):
26+
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
27+
28+
def apply(self, res: CausalTestResult) -> bool:
29+
if res.test_value.type == "ate":
30+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
31+
if res.test_value.type == "risk_ratio":
32+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
33+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
34+
35+
def __str__(self):
36+
return "Changed"
37+
38+
39+
class NoEffect(CausalTestOutcome):
40+
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
41+
42+
def apply(self, res: CausalTestResult) -> bool:
43+
if res.test_value.type == "ate":
44+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
45+
if res.test_value.type == "risk_ratio":
46+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
47+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
48+
49+
def __str__(self):
50+
return "Unchanged"
51+
52+
53+
class ExactValue(SomeEffect):
2654
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
2755

2856
def __init__(self, value: float, tolerance: float = None):
@@ -33,57 +61,32 @@ def __init__(self, value: float, tolerance: float = None):
3361
self.tolerance = tolerance
3462

3563
def apply(self, res: CausalTestResult) -> bool:
64+
super().apply()
3665
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
3766

3867
def __str__(self):
3968
return f"ExactValue: {self.value}±{self.tolerance}"
4069

4170

42-
class Positive(CausalTestOutcome):
71+
class Positive(SomeEffect):
4372
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
4473

4574
def apply(self, res: CausalTestResult) -> bool:
75+
super().apply()
4676
if res.test_value.type == "ate":
4777
return res.test_value.value > 0
4878
if res.test_value.type == "risk_ratio":
4979
return res.test_value.value > 1
5080
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
5181

5282

53-
class Negative(CausalTestOutcome):
83+
class Negative(SomeEffect):
5484
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
5585

5686
def apply(self, res: CausalTestResult) -> bool:
87+
super().apply()
5788
if res.test_value.type == "ate":
5889
return res.test_value.value < 0
5990
if res.test_value.type == "risk_ratio":
6091
return res.test_value.value < 1
6192
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
62-
63-
64-
class SomeEffect(CausalTestOutcome):
65-
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
66-
67-
def apply(self, res: CausalTestResult) -> bool:
68-
if res.test_value.type == "ate":
69-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
70-
if res.test_value.type == "risk_ratio":
71-
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
72-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
73-
74-
def __str__(self):
75-
return "Changed"
76-
77-
78-
class NoEffect(CausalTestOutcome):
79-
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
80-
81-
def apply(self, res: CausalTestResult) -> bool:
82-
if res.test_value.type == "ate":
83-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
84-
if res.test_value.type == "risk_ratio":
85-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
86-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
87-
88-
def __str__(self):
89-
return "Unchanged"

0 commit comments

Comments
 (0)