@@ -28,14 +28,13 @@ class SomeEffect(CausalTestOutcome):
28
28
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
29
29
30
30
def apply (self , res : CausalTestResult ) -> bool :
31
- if res .test_value .type == "ate" :
32
- return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
33
- if res .test_value .type == "coefficient" :
34
- ci_low = res .ci_low () if isinstance (res .ci_low (), Iterable ) else [res .ci_low ()]
35
- ci_high = res .ci_high () if isinstance (res .ci_high (), Iterable ) else [res .ci_high ()]
36
- return any (0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low , ci_high in zip (ci_low , ci_high ))
37
31
if res .test_value .type == "risk_ratio" :
38
- return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
32
+ return any (
33
+ 1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low , ci_high in zip (res .ci_low (), res .ci_high ()))
34
+ if res .test_value .type == "coefficient" or res .test_value .type == "ate" :
35
+ return any (
36
+ 0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low , ci_high in zip (res .ci_low (), res .ci_high ()))
37
+
39
38
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
40
39
41
40
@@ -52,23 +51,21 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
52
51
self .ctol = ctol
53
52
54
53
def apply (self , res : CausalTestResult ) -> bool :
55
- if res .test_value .type == "ate" :
56
- return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .test_value .value ) < self .atol )[0 ]
57
- if res .test_value .type == "coefficient" :
58
- ci_low = res .ci_low () if isinstance (res .ci_low (), Iterable ) else [res .ci_low ()]
59
- ci_high = res .ci_high () if isinstance (res .ci_high (), Iterable ) else [res .ci_high ()]
54
+ if res .test_value .type == "risk_ratio" :
55
+ return any (ci_low < 1 < ci_high or np .isclose (value , 1.0 , atol = self .atol ) for ci_low , ci_high , value in
56
+ zip (res .ci_low (), res .ci_high (), res .test_value .value ))
57
+ elif res .test_value .type == "coefficient" or res .test_value .type == "ate" :
60
58
value = res .test_value .value if isinstance (res .ci_high (), Iterable ) else [res .test_value .value ]
61
59
value = value [0 ] if isinstance (value [0 ], pd .Series ) else value
62
60
return (
63
- sum (
64
- not ((ci_low < 0 < ci_high ) or abs (v ) < self .atol )
65
- for ci_low , ci_high , v in zip (ci_low , ci_high , value )
66
- )
67
- / len (value )
68
- < self .ctol
61
+ sum (
62
+ not ((ci_low < 0 < ci_high ) or abs (v ) < self .atol )
63
+ for ci_low , ci_high , v in zip (res . ci_low (), res . ci_high () , value )
64
+ )
65
+ / len (value )
66
+ < self .ctol
69
67
)
70
- if res .test_value .type == "risk_ratio" :
71
- return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = self .atol )
68
+
72
69
raise ValueError (f"Test Value type { res .test_value .type } is not valid for this TestOutcome" )
73
70
74
71
0 commit comments