@@ -22,7 +22,35 @@ def __str__(self) -> str:
22
22
return type (self ).__name__
23
23
24
24
25
- class ExactValue (CausalTestOutcome ):
25
+ class SomeEffect (CausalTestOutcome ):
26
+ """An extension of TestOutcome representing that the expected causal effect should not be zero."""
27
+
28
+ def apply (self , res : CausalTestResult ) -> bool :
29
+ if res .test_value .type == "ate" :
30
+ return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
31
+ if res .test_value .type == "risk_ratio" :
32
+ return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
33
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
34
+
35
+ def __str__ (self ):
36
+ return "Changed"
37
+
38
+
39
+ class NoEffect (CausalTestOutcome ):
40
+ """An extension of TestOutcome representing that the expected causal effect should be zero."""
41
+
42
+ def apply (self , res : CausalTestResult ) -> bool :
43
+ if res .test_value .type == "ate" :
44
+ return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .test_value .value ) < 1e-10 )
45
+ if res .test_value .type == "risk_ratio" :
46
+ return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = 1e-10 )
47
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
48
+
49
+ def __str__ (self ):
50
+ return "Unchanged"
51
+
52
+
53
+ class ExactValue (SomeEffect ):
26
54
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
27
55
28
56
def __init__ (self , value : float , tolerance : float = None ):
@@ -33,57 +61,32 @@ def __init__(self, value: float, tolerance: float = None):
33
61
self .tolerance = tolerance
34
62
35
63
def apply (self , res : CausalTestResult ) -> bool :
64
+ super ().apply ()
36
65
return np .isclose (res .test_value .value , self .value , atol = self .tolerance )
37
66
38
67
def __str__ (self ):
39
68
return f"ExactValue: { self .value } ±{ self .tolerance } "
40
69
41
70
42
- class Positive (CausalTestOutcome ):
71
+ class Positive (SomeEffect ):
43
72
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
44
73
45
74
def apply (self , res : CausalTestResult ) -> bool :
75
+ super ().apply ()
46
76
if res .test_value .type == "ate" :
47
77
return res .test_value .value > 0
48
78
if res .test_value .type == "risk_ratio" :
49
79
return res .test_value .value > 1
50
80
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
51
81
52
82
53
- class Negative (CausalTestOutcome ):
83
+ class Negative (SomeEffect ):
54
84
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
55
85
56
86
def apply (self , res : CausalTestResult ) -> bool :
87
+ super ().apply ()
57
88
if res .test_value .type == "ate" :
58
89
return res .test_value .value < 0
59
90
if res .test_value .type == "risk_ratio" :
60
91
return res .test_value .value < 1
61
92
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
62
-
63
-
64
- class SomeEffect (CausalTestOutcome ):
65
- """An extension of TestOutcome representing that the expected causal effect should not be zero."""
66
-
67
- def apply (self , res : CausalTestResult ) -> bool :
68
- if res .test_value .type == "ate" :
69
- return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
70
- if res .test_value .type == "risk_ratio" :
71
- return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
72
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
73
-
74
- def __str__ (self ):
75
- return "Changed"
76
-
77
-
78
- class NoEffect (CausalTestOutcome ):
79
- """An extension of TestOutcome representing that the expected causal effect should be zero."""
80
-
81
- def apply (self , res : CausalTestResult ) -> bool :
82
- if res .test_value .type == "ate" :
83
- return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .test_value .value ) < 1e-10 )
84
- if res .test_value .type == "risk_ratio" :
85
- return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = 1e-10 )
86
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
87
-
88
- def __str__ (self ):
89
- return "Unchanged"
0 commit comments