Skip to content

Commit 907079f

Browse files
committed
Pytest
1 parent 7c40163 commit 907079f

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

causal_testing/testing/causal_test_outcome.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ def apply(self, res: CausalTestResult) -> bool:
3232
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
3333
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
3434

35-
def __str__(self):
36-
return "Changed"
37-
3835

3936
class NoEffect(CausalTestOutcome):
4037
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
@@ -46,9 +43,6 @@ def apply(self, res: CausalTestResult) -> bool:
4643
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
4744
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
4845

49-
def __str__(self):
50-
return "Unchanged"
51-
5246

5347
class ExactValue(SomeEffect):
5448
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
@@ -61,30 +55,33 @@ def __init__(self, value: float, tolerance: float = None):
6155
self.tolerance = tolerance
6256

6357
def apply(self, res: CausalTestResult) -> bool:
64-
super().apply()
58+
if res.ci_valid():
59+
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.tolerance)
6560
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
6661

6762
def __str__(self):
6863
return f"ExactValue: {self.value}±{self.tolerance}"
6964

7065

71-
class Positive(SomeEffect):
66+
class Positive(CausalTestOutcome):
7267
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
7368

7469
def apply(self, res: CausalTestResult) -> bool:
75-
super().apply()
70+
if res.ci_valid() and not super().apply(res):
71+
return False
7672
if res.test_value.type == "ate":
7773
return res.test_value.value > 0
7874
if res.test_value.type == "risk_ratio":
7975
return res.test_value.value > 1
8076
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
8177

8278

83-
class Negative(SomeEffect):
79+
class Negative(CausalTestOutcome):
8480
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
8581

8682
def apply(self, res: CausalTestResult) -> bool:
87-
super().apply()
83+
if res.ci_valid() and not super().apply(res):
84+
return False
8885
if res.test_value.type == "ate":
8986
return res.test_value.value < 0
9087
if res.test_value.type == "risk_ratio":

causal_testing/testing/causal_test_result.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def ci_high(self):
8585
return max(self.confidence_intervals)
8686
return None
8787

88+
def ci_valid(self) -> bool:
89+
"""Return whether or not the result has valid confidence invervals"""
90+
return self.ci_low() and self.ci_high()
91+
8892
def summary(self):
8993
"""Summarise the causal test result as an intuitive sentence."""
9094
print(

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,17 @@ def test_Positive_fail(self):
8585
ev = Positive()
8686
self.assertFalse(ev.apply(ctr))
8787

88+
def test_Positive_fail_ci(self):
89+
test_value = TestValue(type="ate", value=0)
90+
ctr = CausalTestResult(
91+
estimator=self.estimator,
92+
test_value=test_value,
93+
confidence_intervals=[-1, 1],
94+
effect_modifier_configuration=None,
95+
)
96+
ev = Positive()
97+
self.assertFalse(ev.apply(ctr))
98+
8899
def test_Negative_pass(self):
89100
test_value = TestValue(type="ate", value=-5.05)
90101
ctr = CausalTestResult(
@@ -107,6 +118,17 @@ def test_Negative_fail(self):
107118
ev = Negative()
108119
self.assertFalse(ev.apply(ctr))
109120

121+
def test_Negative_fail_ci(self):
122+
test_value = TestValue(type="ate", value=0)
123+
ctr = CausalTestResult(
124+
estimator=self.estimator,
125+
test_value=test_value,
126+
confidence_intervals=[-1, 1],
127+
effect_modifier_configuration=None,
128+
)
129+
ev = Negative()
130+
self.assertFalse(ev.apply(ctr))
131+
110132
def test_exactValue_pass(self):
111133
test_value = TestValue(type="ate", value=5.05)
112134
ctr = CausalTestResult(

0 commit comments

Comments
 (0)