Skip to content

Commit b5c05d3

Browse files
committed
Higher coverage
1 parent 17185f4 commit b5c05d3

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

causal_testing/testing/causal_test_result.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,26 +74,22 @@ def to_dict(self):
7474
"adjustment_set": self.adjustment_set,
7575
"test_value": self.test_value,
7676
}
77-
if self.confidence_intervals:
77+
if self.confidence_intervals and all(self.confidence_intervals):
7878
base_dict["ci_low"] = min(self.confidence_intervals)
7979
base_dict["ci_high"] = max(self.confidence_intervals)
8080
return base_dict
8181

8282
def ci_low(self):
8383
"""Return the lower bracket of the confidence intervals."""
84-
if not self.confidence_intervals:
85-
return None
86-
if any([x is None for x in self.confidence_intervals]):
87-
return None
88-
return min(self.confidence_intervals)
84+
if self.confidence_intervals and all(self.confidence_intervals):
85+
return min(self.confidence_intervals)
86+
return None
8987

9088
def ci_high(self):
9189
"""Return the higher bracket of the confidence intervals."""
92-
if not self.confidence_intervals:
93-
return None
94-
if any([x is None for x in self.confidence_intervals]):
95-
return None
96-
return max(self.confidence_intervals)
90+
if self.confidence_intervals and all(self.confidence_intervals):
91+
return max(self.confidence_intervals)
92+
return None
9793

9894
def summary(self):
9995
"""Summarise the causal test result as an intuitive sentence."""

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,51 @@
66
class TestCausalTestOutcome(unittest.TestCase):
77
"""Test the TestCausalTestOutcome basic methods."""
88

9+
def test_None_ci(self):
10+
test_value = TestValue(type="ate", value=0)
11+
ctr = CausalTestResult(
12+
treatment="A",
13+
outcome="A",
14+
treatment_value=1,
15+
control_value=0,
16+
adjustment_set={},
17+
test_value=test_value,
18+
confidence_intervals=[None, None],
19+
effect_modifier_configuration=None,
20+
)
21+
22+
self.assertIsNone(ctr.ci_low())
23+
self.assertIsNone(ctr.ci_high())
24+
self.assertEqual(ctr.to_dict(),
25+
{"treatment": "A",
26+
"control_value": 0,
27+
"treatment_value": 1,
28+
"outcome": "A",
29+
"adjustment_set": set(),
30+
"test_value": test_value})
31+
932
def test_empty_adjustment_set(self):
33+
test_value = TestValue(type="ate", value=0)
1034
ctr = CausalTestResult(
1135
treatment="A",
1236
outcome="A",
1337
treatment_value=1,
1438
control_value=0,
1539
adjustment_set={},
16-
test_value=0,
40+
test_value=test_value,
1741
confidence_intervals=None,
1842
effect_modifier_configuration=None,
1943
)
2044

2145
self.assertIsNone(ctr.ci_low())
2246
self.assertIsNone(ctr.ci_high())
47+
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
48+
"Treatment: A\n"
49+
"Control value: 0\n"
50+
"Treatment value: 1\n"
51+
"Outcome: A\n"
52+
"Adjustment set: set()\n"
53+
"ate: 0\n" ))
2354

2455
def test_exactValue_pass(self):
2556
test_value = TestValue(type="ate", value=5.05)
@@ -80,3 +111,20 @@ def test_someEffect_fail(self):
80111
)
81112
ev = SomeEffect()
82113
self.assertFalse(ev.apply(ctr))
114+
self.assertEqual(str(ctr), ("Causal Test Result\n==============\n"
115+
"Treatment: A\n"
116+
"Control value: 0\n"
117+
"Treatment value: 1\n"
118+
"Outcome: A\n"
119+
"Adjustment set: set()\n"
120+
"ate: 0\n"
121+
"Confidence intervals: [-0.1, 0.2]\n" ))
122+
self.assertEqual(ctr.to_dict(),
123+
{"treatment": "A",
124+
"control_value": 0,
125+
"treatment_value": 1,
126+
"outcome": "A",
127+
"adjustment_set": set(),
128+
"test_value": test_value,
129+
"ci_low": -0.1,
130+
"ci_high": 0.2})

0 commit comments

Comments
 (0)