Skip to content

Commit 6515afc

Browse files
committed
pytests
1 parent c867ac6 commit 6515afc

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

causal_testing/testing/causal_effect.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,29 @@ class ExactValue(CausalEffect):
7979
def __init__(self, value: float, atol: float = None, ci_low: float = None, ci_high: float = None):
8080
if (ci_low is not None) ^ (ci_high is not None):
8181
raise ValueError("If specifying confidence intervals, must specify `ci_low` and `ci_high` parameters.")
82+
if atol is not None and atol < 0:
83+
raise ValueError("Tolerance must be an absolute (positive) value.")
84+
8285
self.value = value
8386
self.ci_low = ci_low
8487
self.ci_high = ci_high
8588
self.atol = atol if atol is not None else abs(value * 0.05)
86-
if self.atol < 0:
87-
raise ValueError("Tolerance must be an absolute (positive) value.")
89+
90+
if self.ci_low is not None and self.ci_high is not None:
91+
if not self.ci_low <= self.value <= self.ci_high:
92+
raise ValueError("Specified value falls outside the specified confidence intervals.")
93+
if self.value - self.atol < self.ci_low or self.value + self.atol > self.ci_high:
94+
raise ValueError(
95+
"Arithmetic tolerance falls outside the confidence intervals. Try specifyin a smaller value of atol."
96+
)
8897

8998
def apply(self, res: CausalTestResult) -> bool:
9099
close = np.isclose(res.test_value.value, self.value, atol=self.atol)
91100
if res.ci_valid() and self.ci_low is not None and self.ci_high is not None:
92-
return close and self.ci_low <= res.ci_low() and self.ci_high >= res.ci_high()
101+
return all(
102+
close and ci_low <= ci_low and ci_high >= ci_high
103+
for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
104+
)
93105
return close
94106

95107
def __str__(self):

tests/testing_tests/test_causal_test_outcome.py renamed to tests/testing_tests/test_causal_effect.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,28 @@ def test_exactValue_pass_ci(self):
181181
ev = ExactValue(5, 0.1)
182182
self.assertTrue(ev.apply(ctr))
183183

184+
def test_exactValue_ci_pass_ci(self):
185+
test_value = TestValue(type="ate", value=pd.Series(5.05))
186+
ctr = CausalTestResult(
187+
estimator=self.estimator,
188+
test_value=test_value,
189+
confidence_intervals=[pd.Series(4.1), pd.Series(5.9)],
190+
effect_modifier_configuration=None,
191+
)
192+
ev = ExactValue(5, ci_low=4, ci_high=6)
193+
self.assertTrue(ev.apply(ctr))
194+
195+
def test_exactValue_ci_fail_ci(self):
196+
test_value = TestValue(type="ate", value=pd.Series(5.05))
197+
ctr = CausalTestResult(
198+
estimator=self.estimator,
199+
test_value=test_value,
200+
confidence_intervals=[pd.Series(3.9), pd.Series(6.1)],
201+
effect_modifier_configuration=None,
202+
)
203+
ev = ExactValue(5, ci_low=4, ci_high=6)
204+
self.assertTrue(ev.apply(ctr))
205+
184206
def test_exactValue_fail(self):
185207
test_value = TestValue(type="ate", value=pd.Series(0))
186208
ctr = CausalTestResult(
@@ -196,6 +218,22 @@ def test_invalid_atol(self):
196218
with self.assertRaises(ValueError):
197219
ExactValue(5, -0.1)
198220

221+
def test_unspecified_ci_high(self):
222+
with self.assertRaises(ValueError):
223+
ExactValue(5, ci_low=-0.1)
224+
225+
def test_unspecified_ci_low(self):
226+
with self.assertRaises(ValueError):
227+
ExactValue(5, ci_high=-0.1)
228+
229+
def test_invalid_ci_range(self):
230+
with self.assertRaises(ValueError):
231+
ExactValue(5, ci_low=6, ci_high=7, atol=0.05)
232+
233+
def test_invalid_ci_atol(self):
234+
with self.assertRaises(ValueError):
235+
ExactValue(1000, ci_low=1001, ci_high=1002, atol=0.05)
236+
199237
def test_invalid(self):
200238
test_value = TestValue(type="invalid", value=pd.Series(5.05))
201239
ctr = CausalTestResult(

0 commit comments

Comments
 (0)