3
3
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
4
4
5
5
from abc import ABC , abstractmethod
6
+ from collections .abc import Iterable
6
7
import numpy as np
7
8
8
9
from causal_testing .testing .causal_test_result import CausalTestResult
@@ -26,9 +27,12 @@ class SomeEffect(CausalTestOutcome):
26
27
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
27
28
28
29
def apply (self , res : CausalTestResult ) -> bool :
29
- if res .test_value .type in {"ate" , "coefficient" }:
30
- return any ([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low , ci_high in zip (res .ci_low (), res .ci_high ())])
31
- # return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
30
+ if res .test_value .type == "ate" :
31
+ return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
32
+ if res .test_value .type == "coefficient" :
33
+ ci_low = res .ci_low () if isinstance (res .ci_low (), Iterable ) else [res .ci_low ()]
34
+ ci_high = res .ci_high () if isinstance (res .ci_high (), Iterable ) else [res .ci_high ()]
35
+ return any ([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low , ci_high in zip (ci_low , ci_high )])
32
36
if res .test_value .type == "risk_ratio" :
33
37
return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
34
38
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
@@ -38,10 +42,15 @@ class NoEffect(CausalTestOutcome):
38
42
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
39
43
40
44
def apply (self , res : CausalTestResult , threshold : float = 1e-10 ) -> bool :
41
- print ("RESULT" , res )
42
- if res .test_value .type in {"ate" , "coefficient" }:
43
- return all ([ci_low < 0 < ci_high for ci_low , ci_high in zip (res .ci_low (), res .ci_high ())]) or all ([abs (v ) < 1e-10 for v in res .test_value .value ])
44
- # return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
45
+ if res .test_value .type == "ate" :
46
+ return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .test_value .value ) < 1e-10 )
47
+ if res .test_value .type == "coefficient" :
48
+ ci_low = res .ci_low () if isinstance (res .ci_low (), Iterable ) else [res .ci_low ()]
49
+ ci_high = res .ci_high () if isinstance (res .ci_high (), Iterable ) else [res .ci_high ()]
50
+ value = res .test_value .value if isinstance (res .ci_high (), Iterable ) else [res .test_value .value ]
51
+ return all ([ci_low < 0 < ci_high for ci_low , ci_high in zip (ci_low , ci_high )]) or all (
52
+ [abs (v ) < 1e-10 for v in value ]
53
+ )
45
54
if res .test_value .type == "risk_ratio" :
46
55
return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = 1e-10 )
47
56
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
0 commit comments