Skip to content

Commit 35bae1f

Browse files
SomeEffect and NoneEffect applys now work with pd.Series
1 parent 9641319 commit 35bae1f

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

causal_testing/testing/causal_test_outcome.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ class SomeEffect(CausalTestOutcome):
2828
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2929

3030
def apply(self, res: CausalTestResult) -> bool:
31-
if res.test_value.type == "ate":
32-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
33-
if res.test_value.type == "coefficient":
34-
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
35-
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
36-
return any(0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high))
3731
if res.test_value.type == "risk_ratio":
38-
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
32+
return any(
33+
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
34+
if res.test_value.type == "coefficient" or res.test_value.type == "ate":
35+
return any(
36+
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
37+
3938
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
4039

4140

@@ -52,23 +51,21 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5251
self.ctol = ctol
5352

5453
def apply(self, res: CausalTestResult) -> bool:
55-
if res.test_value.type == "ate":
56-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)[0]
57-
if res.test_value.type == "coefficient":
58-
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
59-
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
54+
if res.test_value.type == "risk_ratio":
55+
return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
56+
zip(res.ci_low(), res.ci_high(), res.test_value.value))
57+
elif res.test_value.type == "coefficient" or res.test_value.type == "ate":
6058
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
6159
value = value[0] if isinstance(value[0], pd.Series) else value
6260
return (
63-
sum(
64-
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
65-
for ci_low, ci_high, v in zip(ci_low, ci_high, value)
66-
)
67-
/ len(value)
68-
< self.ctol
61+
sum(
62+
not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
63+
for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
64+
)
65+
/ len(value)
66+
< self.ctol
6967
)
70-
if res.test_value.type == "risk_ratio":
71-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)
68+
7269
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7370

7471

0 commit comments

Comments
 (0)