@@ -73,22 +73,24 @@ def apply(self, res: CausalTestResult) -> bool:
73
73
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
74
74
75
75
76
- class ExactValue (SomeEffect ):
76
+ class ExactValue (CausalTestOutcome ):
77
77
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
78
78
79
- def __init__ (self , value : float , atol : float = None ):
79
+ def __init__ (self , value : float , atol : float = None , ci_low : float = None , ci_high : float = None ):
80
+ if (ci_low is not None ) ^ (ci_high is not None ):
81
+ raise ValueError ("If specifying confidence intervals, must specify `ci_low` and `ci_high` parameters." )
80
82
self .value = value
81
- if atol is None :
82
- self .atol = abs (value * 0.05 )
83
- else :
84
- self .atol = atol
83
+ self .ci_low = ci_low
84
+ self .ci_high = ci_high
85
+ self .atol = atol if atol is not None else abs (value * 0.05 )
85
86
if self .atol < 0 :
86
- raise ValueError ("Tolerance must be an absolute value." )
87
+ raise ValueError ("Tolerance must be an absolute (positive) value." )
87
88
88
89
def apply (self , res : CausalTestResult ) -> bool :
89
- if res .ci_valid ():
90
- return super ().apply (res ) and np .isclose (res .test_value .value , self .value , atol = self .atol )
91
- return np .isclose (res .test_value .value , self .value , atol = self .atol )
90
+ close = np .isclose (res .test_value .value , self .value , atol = self .atol )
91
+ if res .ci_valid () and self .ci_low is not None and self .ci_high is not None :
92
+ return close and self .ci_low <= res .ci_low () and self .ci_high >= res .ci_high ()
93
+ return close
92
94
93
95
def __str__ (self ):
94
96
return f"ExactValue: { self .value } ±{ self .atol } "
0 commit comments