Skip to content

Commit eb96ea7

Browse files
authored
Merge pull request #321 from CITCOM-project/jmafoster1/exact-value-confidence-intervals
Jmafoster1/exact value confidence intervals
2 parents 8419eb2 + 192cfca commit eb96ea7

File tree

12 files changed

+101
-42
lines changed

12 files changed

+101
-42
lines changed

causal_testing/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from causal_testing.specification.causal_specification import CausalSpecification
1818
from causal_testing.testing.causal_test_case import CausalTestCase
1919
from causal_testing.testing.base_test_case import BaseTestCase
20-
from causal_testing.testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
20+
from causal_testing.testing.causal_effect import NoEffect, SomeEffect, Positive, Negative
2121
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
2222
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
2323
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator

causal_testing/testing/causal_test_outcome.py renamed to causal_testing/testing/causal_effect.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=too-few-public-methods
2-
"""This module contains the CausalTestOutcome abstract class, as well as the concrete extension classes:
2+
"""This module contains the CausalEffect abstract class, as well as the concrete extension classes:
33
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
44

55
from abc import ABC, abstractmethod
@@ -9,7 +9,7 @@
99
from causal_testing.testing.causal_test_result import CausalTestResult
1010

1111

12-
class CausalTestOutcome(ABC):
12+
class CausalEffect(ABC):
1313
"""An abstract class representing an expected causal effect."""
1414

1515
@abstractmethod
@@ -23,8 +23,8 @@ def __str__(self) -> str:
2323
return type(self).__name__
2424

2525

26-
class SomeEffect(CausalTestOutcome):
27-
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
26+
class SomeEffect(CausalEffect):
27+
"""An extension of CausalEffect representing that the expected causal effect should not be zero."""
2828

2929
def apply(self, res: CausalTestResult) -> bool:
3030
if res.ci_low() is None or res.ci_high() is None:
@@ -38,11 +38,11 @@ def apply(self, res: CausalTestResult) -> bool:
3838
0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
3939
)
4040

41-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
41+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this CausalEffect")
4242

4343

44-
class NoEffect(CausalTestOutcome):
45-
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
44+
class NoEffect(CausalEffect):
45+
"""An extension of CausalEffect representing that the expected causal effect should be zero."""
4646

4747
def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
4848
"""
@@ -70,58 +70,69 @@ def apply(self, res: CausalTestResult) -> bool:
7070
< self.ctol
7171
)
7272

73-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
73+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this CausalEffect")
7474

7575

76-
class ExactValue(SomeEffect):
77-
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
76+
class ExactValue(CausalEffect):
77+
"""An extension of CausalEffect representing that the expected causal effect should be a specific value."""
78+
79+
def __init__(self, value: float, atol: float = None, ci_low: float = None, ci_high: float = None):
80+
if (ci_low is not None) ^ (ci_high is not None):
81+
raise ValueError("If specifying confidence intervals, must specify `ci_low` and `ci_high` parameters.")
82+
if atol is not None and atol < 0:
83+
raise ValueError("Tolerance must be an absolute (positive) value.")
7884

79-
def __init__(self, value: float, atol: float = None):
8085
self.value = value
81-
if atol is None:
82-
self.atol = abs(value * 0.05)
83-
else:
84-
self.atol = atol
85-
if self.atol < 0:
86-
raise ValueError("Tolerance must be an absolute value.")
86+
self.ci_low = ci_low
87+
self.ci_high = ci_high
88+
self.atol = atol if atol is not None else abs(value * 0.05)
89+
90+
if self.ci_low is not None and self.ci_high is not None:
91+
if not self.ci_low <= self.value <= self.ci_high:
92+
raise ValueError("Specified value falls outside the specified confidence intervals.")
93+
if self.value - self.atol < self.ci_low or self.value + self.atol > self.ci_high:
94+
raise ValueError(
95+
"Arithmetic tolerance falls outside the confidence intervals."
96+
"Try specifying a smaller value of atol."
97+
)
8798

8899
def apply(self, res: CausalTestResult) -> bool:
89-
if res.ci_valid():
90-
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.atol)
91-
return np.isclose(res.test_value.value, self.value, atol=self.atol)
100+
close = np.isclose(res.test_value.value, self.value, atol=self.atol)
101+
if res.ci_valid() and self.ci_low is not None and self.ci_high is not None:
102+
return all(
103+
close and self.ci_low <= ci_low and self.ci_high >= ci_high
104+
for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
105+
)
106+
return close
92107

93108
def __str__(self):
94109
return f"ExactValue: {self.value}±{self.atol}"
95110

96111

97112
class Positive(SomeEffect):
98-
"""An extension of TestOutcome representing that the expected causal effect should be positive.
113+
"""An extension of CausalEffect representing that the expected causal effect should be positive.
99114
Currently only single values are supported for the test value"""
100115

101116
def apply(self, res: CausalTestResult) -> bool:
102-
if res.ci_valid() and not super().apply(res):
103-
return False
104117
if len(res.test_value.value) > 1:
105118
raise ValueError("Positive Effects are currently only supported on single float datatypes")
106119
if res.test_value.type in {"ate", "coefficient"}:
107120
return bool(res.test_value.value[0] > 0)
108121
if res.test_value.type == "risk_ratio":
109122
return bool(res.test_value.value[0] > 1)
110-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
123+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this CausalEffect")
111124

112125

113126
class Negative(SomeEffect):
114-
"""An extension of TestOutcome representing that the expected causal effect should be negative.
127+
"""An extension of CausalEffect representing that the expected causal effect should be negative.
115128
Currently only single values are supported for the test value"""
116129

117130
def apply(self, res: CausalTestResult) -> bool:
118-
if res.ci_valid() and not super().apply(res):
119-
return False
120131
if len(res.test_value.value) > 1:
121132
raise ValueError("Negative Effects are currently only supported on single float datatypes")
122133
if res.test_value.type in {"ate", "coefficient"}:
123134
return bool(res.test_value.value[0] < 0)
124135
if res.test_value.type == "risk_ratio":
125136
return bool(res.test_value.value[0] < 1)
126137
# Dead code but necessary for pylint
127-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
138+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this CausalEffect")

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from causal_testing.specification.variable import Variable
7-
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
7+
from causal_testing.testing.causal_effect import CausalEffect
88
from causal_testing.testing.base_test_case import BaseTestCase
99
from causal_testing.estimation.abstract_estimator import Estimator
1010
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
@@ -26,7 +26,7 @@ def __init__(
2626
# pylint: disable=too-many-arguments
2727
self,
2828
base_test_case: BaseTestCase,
29-
expected_causal_effect: CausalTestOutcome,
29+
expected_causal_effect: CausalEffect,
3030
estimate_type: str = "ate",
3131
estimate_params: dict = None,
3232
effect_modifier_configuration: dict[Variable:Any] = None,

dafni/main_dafni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pandas as pd
1111
from causal_testing.specification.scenario import Scenario
1212
from causal_testing.specification.variable import Input, Output
13-
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect, SomeEffect
13+
from causal_testing.testing.causal_effect import Positive, Negative, NoEffect, SomeEffect
1414
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1515
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
1616
from causal_testing.json_front.json_class import JsonUtility

docs/source/usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ the given output and input and the desired effect. This information is the minim
4444
4545
from causal_testing.testing.base_test_case import BaseTestCase
4646
from causal_testing.testing.causal_test_case import CausalTestCase
47-
from causal_testing.testing.causal_test_outcome import Positive
47+
from causal_testing.testing.causal_effect import Positive
4848
from causal_testing.testing.effect import Effect
4949
5050
base_test_case = BaseTestCase(

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from causal_testing.specification.variable import Input, Output
99
from causal_testing.testing.causal_test_case import CausalTestCase
10-
from causal_testing.testing.causal_test_outcome import Positive
10+
from causal_testing.testing.causal_effect import Positive
1111
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1212
from causal_testing.testing.base_test_case import BaseTestCase
1313

examples/covasim_/vaccinating_elderly/example_vaccine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from causal_testing.specification.variable import Input, Output
88
from causal_testing.specification.causal_specification import CausalSpecification
99
from causal_testing.testing.causal_test_case import CausalTestCase
10-
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
10+
from causal_testing.testing.causal_effect import Positive, Negative, NoEffect
1111
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1212
from causal_testing.testing.base_test_case import BaseTestCase
1313

examples/lr91/example_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from causal_testing.specification.variable import Input, Output
77
from causal_testing.specification.causal_specification import CausalSpecification
88
from causal_testing.testing.causal_test_case import CausalTestCase
9-
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
9+
from causal_testing.testing.causal_effect import Positive, Negative, NoEffect
1010
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1111
from causal_testing.testing.base_test_case import BaseTestCase
1212
from matplotlib.pyplot import rcParams

examples/poisson-line-process/example_pure_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from causal_testing.specification.variable import Input, Output
99
from causal_testing.specification.causal_specification import CausalSpecification
1010
from causal_testing.testing.causal_test_case import CausalTestCase
11-
from causal_testing.testing.causal_test_outcome import ExactValue, Positive
11+
from causal_testing.testing.causal_effect import ExactValue, Positive
1212
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1313
from causal_testing.estimation.abstract_estimator import Estimator
1414
from causal_testing.testing.base_test_case import BaseTestCase

tests/testing_tests/test_causal_test_outcome.py renamed to tests/testing_tests/test_causal_effect.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import unittest
22
import pandas as pd
3-
from causal_testing.testing.causal_test_outcome import ExactValue, SomeEffect, Positive, Negative, NoEffect
3+
from causal_testing.testing.causal_effect import ExactValue, SomeEffect, Positive, Negative, NoEffect
44
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
55
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
66
from causal_testing.utils.validation import CausalValidator
77
from causal_testing.testing.base_test_case import BaseTestCase
88
from causal_testing.specification.variable import Input, Output
99

1010

11-
class TestCausalTestOutcome(unittest.TestCase):
12-
"""Test the TestCausalTestOutcome basic methods."""
11+
class TestCausalEffect(unittest.TestCase):
12+
"""Test the TestCausalEffect basic methods."""
1313

1414
def setUp(self) -> None:
1515
base_test_case = BaseTestCase(Input("A", float), Output("A", float))
@@ -181,6 +181,28 @@ def test_exactValue_pass_ci(self):
181181
ev = ExactValue(5, 0.1)
182182
self.assertTrue(ev.apply(ctr))
183183

184+
def test_exactValue_ci_pass_ci(self):
185+
test_value = TestValue(type="ate", value=pd.Series(5.05))
186+
ctr = CausalTestResult(
187+
estimator=self.estimator,
188+
test_value=test_value,
189+
confidence_intervals=[pd.Series(4.1), pd.Series(5.9)],
190+
effect_modifier_configuration=None,
191+
)
192+
ev = ExactValue(5, ci_low=4, ci_high=6)
193+
self.assertTrue(ev.apply(ctr))
194+
195+
def test_exactValue_ci_fail_ci(self):
196+
test_value = TestValue(type="ate", value=pd.Series(5.05))
197+
ctr = CausalTestResult(
198+
estimator=self.estimator,
199+
test_value=test_value,
200+
confidence_intervals=[pd.Series(3.9), pd.Series(6.1)],
201+
effect_modifier_configuration=None,
202+
)
203+
ev = ExactValue(5, ci_low=4, ci_high=6)
204+
self.assertFalse(ev.apply(ctr))
205+
184206
def test_exactValue_fail(self):
185207
test_value = TestValue(type="ate", value=pd.Series(0))
186208
ctr = CausalTestResult(
@@ -196,6 +218,22 @@ def test_invalid_atol(self):
196218
with self.assertRaises(ValueError):
197219
ExactValue(5, -0.1)
198220

221+
def test_unspecified_ci_high(self):
222+
with self.assertRaises(ValueError):
223+
ExactValue(5, ci_low=-0.1)
224+
225+
def test_unspecified_ci_low(self):
226+
with self.assertRaises(ValueError):
227+
ExactValue(5, ci_high=-0.1)
228+
229+
def test_invalid_ci_range(self):
230+
with self.assertRaises(ValueError):
231+
ExactValue(5, ci_low=6, ci_high=7, atol=0.05)
232+
233+
def test_invalid_ci_atol(self):
234+
with self.assertRaises(ValueError):
235+
ExactValue(1000, ci_low=999, ci_high=1001, atol=50)
236+
199237
def test_invalid(self):
200238
test_value = TestValue(type="invalid", value=pd.Series(5.05))
201239
ctr = CausalTestResult(
@@ -257,6 +295,16 @@ def test_someEffect_fail(self):
257295
self.assertFalse(SomeEffect().apply(ctr))
258296
self.assertTrue(NoEffect().apply(ctr))
259297

298+
def test_someEffect_None(self):
299+
test_value = TestValue(type="ate", value=pd.Series(0))
300+
ctr = CausalTestResult(
301+
estimator=self.estimator,
302+
test_value=test_value,
303+
confidence_intervals=None,
304+
effect_modifier_configuration=None,
305+
)
306+
self.assertEqual(SomeEffect().apply(ctr), None)
307+
260308
def test_someEffect_str(self):
261309
test_value = TestValue(type="ate", value=0)
262310
ctr = CausalTestResult(

0 commit comments

Comments
 (0)