1
1
# pylint: disable=too-few-public-methods
2
- """This module contains the CausalTestOutcome abstract class, as well as the concrete extension classes:
2
+ """This module contains the CausalEffect abstract class, as well as the concrete extension classes:
3
3
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
4
4
5
5
from abc import ABC , abstractmethod
9
9
from causal_testing .testing .causal_test_result import CausalTestResult
10
10
11
11
12
- class CausalTestOutcome (ABC ):
12
+ class CausalEffect (ABC ):
13
13
"""An abstract class representing an expected causal effect."""
14
14
15
15
@abstractmethod
@@ -23,8 +23,8 @@ def __str__(self) -> str:
23
23
return type (self ).__name__
24
24
25
25
26
- class SomeEffect (CausalTestOutcome ):
27
- """An extension of TestOutcome representing that the expected causal effect should not be zero."""
26
+ class SomeEffect (CausalEffect ):
27
+ """An extension of CausalEffect representing that the expected causal effect should not be zero."""
28
28
29
29
def apply (self , res : CausalTestResult ) -> bool :
30
30
if res .ci_low () is None or res .ci_high () is None :
@@ -38,11 +38,11 @@ def apply(self, res: CausalTestResult) -> bool:
38
38
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low , ci_high in zip (res .ci_low (), res .ci_high ())
39
39
)
40
40
41
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome " )
41
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this CausalEffect " )
42
42
43
43
44
- class NoEffect (CausalTestOutcome ):
45
- """An extension of TestOutcome representing that the expected causal effect should be zero."""
44
+ class NoEffect (CausalEffect ):
45
+ """An extension of CausalEffect representing that the expected causal effect should be zero."""
46
46
47
47
def __init__ (self , atol : float = 1e-10 , ctol : float = 0.05 ):
48
48
"""
@@ -70,58 +70,69 @@ def apply(self, res: CausalTestResult) -> bool:
70
70
< self .ctol
71
71
)
72
72
73
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome " )
73
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this CausalEffect " )
74
74
75
75
76
- class ExactValue (SomeEffect ):
77
- """An extension of TestOutcome representing that the expected causal effect should be a specific value."""
76
+ class ExactValue (CausalEffect ):
77
+ """An extension of CausalEffect representing that the expected causal effect should be a specific value."""
78
+
79
+ def __init__ (self , value : float , atol : float = None , ci_low : float = None , ci_high : float = None ):
80
+ if (ci_low is not None ) ^ (ci_high is not None ):
81
+ raise ValueError ("If specifying confidence intervals, must specify `ci_low` and `ci_high` parameters." )
82
+ if atol is not None and atol < 0 :
83
+ raise ValueError ("Tolerance must be an absolute (positive) value." )
78
84
79
- def __init__ (self , value : float , atol : float = None ):
80
85
self .value = value
81
- if atol is None :
82
- self .atol = abs (value * 0.05 )
83
- else :
84
- self .atol = atol
85
- if self .atol < 0 :
86
- raise ValueError ("Tolerance must be an absolute value." )
86
+ self .ci_low = ci_low
87
+ self .ci_high = ci_high
88
+ self .atol = atol if atol is not None else abs (value * 0.05 )
89
+
90
+ if self .ci_low is not None and self .ci_high is not None :
91
+ if not self .ci_low <= self .value <= self .ci_high :
92
+ raise ValueError ("Specified value falls outside the specified confidence intervals." )
93
+ if self .value - self .atol < self .ci_low or self .value + self .atol > self .ci_high :
94
+ raise ValueError (
95
+ "Arithmetic tolerance falls outside the confidence intervals."
96
+ "Try specifying a smaller value of atol."
97
+ )
87
98
88
99
def apply (self , res : CausalTestResult ) -> bool :
89
- if res .ci_valid ():
90
- return super ().apply (res ) and np .isclose (res .test_value .value , self .value , atol = self .atol )
91
- return np .isclose (res .test_value .value , self .value , atol = self .atol )
100
+ close = np .isclose (res .test_value .value , self .value , atol = self .atol )
101
+ if res .ci_valid () and self .ci_low is not None and self .ci_high is not None :
102
+ return all (
103
+ close and self .ci_low <= ci_low and self .ci_high >= ci_high
104
+ for ci_low , ci_high in zip (res .ci_low (), res .ci_high ())
105
+ )
106
+ return close
92
107
93
108
def __str__ (self ):
94
109
return f"ExactValue: { self .value } ±{ self .atol } "
95
110
96
111
97
112
class Positive (SomeEffect ):
98
- """An extension of TestOutcome representing that the expected causal effect should be positive.
113
+ """An extension of CausalEffect representing that the expected causal effect should be positive.
99
114
Currently only single values are supported for the test value"""
100
115
101
116
def apply (self , res : CausalTestResult ) -> bool :
102
- if res .ci_valid () and not super ().apply (res ):
103
- return False
104
117
if len (res .test_value .value ) > 1 :
105
118
raise ValueError ("Positive Effects are currently only supported on single float datatypes" )
106
119
if res .test_value .type in {"ate" , "coefficient" }:
107
120
return bool (res .test_value .value [0 ] > 0 )
108
121
if res .test_value .type == "risk_ratio" :
109
122
return bool (res .test_value .value [0 ] > 1 )
110
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome " )
123
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this CausalEffect " )
111
124
112
125
113
126
class Negative (SomeEffect ):
114
- """An extension of TestOutcome representing that the expected causal effect should be negative.
127
+ """An extension of CausalEffect representing that the expected causal effect should be negative.
115
128
Currently only single values are supported for the test value"""
116
129
117
130
def apply (self , res : CausalTestResult ) -> bool :
118
- if res .ci_valid () and not super ().apply (res ):
119
- return False
120
131
if len (res .test_value .value ) > 1 :
121
132
raise ValueError ("Negative Effects are currently only supported on single float datatypes" )
122
133
if res .test_value .type in {"ate" , "coefficient" }:
123
134
return bool (res .test_value .value [0 ] < 0 )
124
135
if res .test_value .type == "risk_ratio" :
125
136
return bool (res .test_value .value [0 ] < 1 )
126
137
# Dead code but necessary for pylint
127
- raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome " )
138
+ raise ValueError (f"Test Value type { res .test_value .type } is not valid for this CausalEffect " )
0 commit comments