Skip to content

Commit 01e6c8b

Browse files
Merge branch 'main' into base-causal-test-case
# Conflicts: # tests/testing_tests/test_causal_test_case.py # tests/testing_tests/test_causal_test_engine.py # tests/testing_tests/test_causal_test_outcome.py # tests/testing_tests/test_estimators.py
2 parents 78c1eee + 8d77dea commit 01e6c8b

File tree

13 files changed

+184
-136
lines changed

13 files changed

+184
-136
lines changed

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
177177

178178
result_string = str()
179179
if causal_test_result.ci_low() and causal_test_result.ci_high():
180-
result_string = (
181-
f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
182-
)
180+
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
183181
else:
184-
result_string = causal_test_result.ate
182+
result_string = causal_test_result.test_value.value
185183
if f_flag:
186184
assert test_passes, (
187185
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
@@ -192,7 +190,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
192190
logger.warning(
193191
" FAILED- expected %s, got %s",
194192
causal_test_case.expected_causal_effect,
195-
causal_test_result.ate,
193+
causal_test_result.test_value.value,
196194
)
197195
return failed
198196

causal_testing/testing/causal_test_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from causal_testing.data_collection.data_collector import DataCollector
66
from causal_testing.specification.causal_specification import CausalSpecification
77
from causal_testing.testing.causal_test_case import CausalTestCase
8-
from causal_testing.testing.causal_test_outcome import CausalTestResult
8+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
99
from causal_testing.testing.estimators import Estimator
1010
from causal_testing.testing.base_test_case import BaseTestCase
1111
from causal_testing.testing.causal_test_suite import CausalTestSuite
@@ -178,7 +178,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
178178
treatment_value=estimator.treatment_values,
179179
control_value=estimator.control_values,
180180
adjustment_set=estimator.adjustment_set,
181-
ate=cates_df,
181+
test_value=TestValue("ate", cates_df),
182182
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
183183
confidence_intervals=confidence_intervals,
184184
)
@@ -191,7 +191,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
191191
treatment_value=estimator.treatment_values,
192192
control_value=estimator.control_values,
193193
adjustment_set=estimator.adjustment_set,
194-
ate=risk_ratio,
194+
test_value=TestValue("risk_ratio", risk_ratio),
195195
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
196196
confidence_intervals=confidence_intervals,
197197
)
@@ -204,7 +204,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
204204
treatment_value=estimator.treatment_values,
205205
control_value=estimator.control_values,
206206
adjustment_set=estimator.adjustment_set,
207-
ate=ate,
207+
test_value=TestValue("ate", ate),
208208
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
209209
confidence_intervals=confidence_intervals,
210210
)
@@ -219,7 +219,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
219219
treatment_value=estimator.treatment_values,
220220
control_value=estimator.control_values,
221221
adjustment_set=estimator.adjustment_set,
222-
ate=ate,
222+
test_value=TestValue("ate", ate),
223223
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
224224
confidence_intervals=confidence_intervals,
225225
)
Lines changed: 18 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,7 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Union
3-
2+
from causal_testing.testing.causal_test_result import CausalTestResult
43
import numpy as np
54

6-
from causal_testing.specification.variable import Variable
7-
8-
9-
class CausalTestResult:
10-
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11-
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12-
confidence intervals."""
13-
14-
def __init__(
15-
self,
16-
treatment: tuple,
17-
outcome: tuple,
18-
treatment_value: Union[int, float, str],
19-
control_value: Union[int, float, str],
20-
adjustment_set: set,
21-
ate: float,
22-
confidence_intervals: [float, float] = None,
23-
effect_modifier_configuration: {Variable: Any} = None,
24-
):
25-
self.treatment = treatment
26-
self.outcome = outcome
27-
self.treatment_value = treatment_value
28-
self.control_value = control_value
29-
if adjustment_set:
30-
self.adjustment_set = adjustment_set
31-
else:
32-
self.adjustment_set = set()
33-
self.ate = ate
34-
self.confidence_intervals = confidence_intervals
35-
36-
if effect_modifier_configuration is not None:
37-
self.effect_modifier_configuration = effect_modifier_configuration
38-
else:
39-
self.effect_modifier_configuration = {}
40-
41-
def __str__(self):
42-
base_str = (
43-
f"Causal Test Result\n==============\n"
44-
f"Treatment: {self.treatment[0]}\n"
45-
f"Control value: {self.control_value}\n"
46-
f"Treatment value: {self.treatment_value}\n"
47-
f"Outcome: {self.outcome[0]}\n"
48-
f"Adjustment set: {self.adjustment_set}\n"
49-
f"ATE: {self.ate}\n"
50-
)
51-
confidence_str = ""
52-
if self.confidence_intervals:
53-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
54-
return base_str + confidence_str
55-
56-
def to_dict(self):
57-
base_dict = {
58-
"treatment": self.treatment[0],
59-
"control_value": self.control_value,
60-
"treatment_value": self.treatment_value,
61-
"outcome": self.outcome[0],
62-
"adjustment_set": self.adjustment_set,
63-
"ate": self.ate,
64-
}
65-
if self.confidence_intervals:
66-
base_dict["ci_low"] = min(self.confidence_intervals)
67-
base_dict["ci_high"] = max(self.confidence_intervals)
68-
return base_dict
69-
70-
def ci_low(self):
71-
"""Return the lower bracket of the confidence intervals."""
72-
if not self.confidence_intervals:
73-
return None
74-
return min(self.confidence_intervals)
75-
76-
def ci_high(self):
77-
"""Return the higher bracket of the confidence intervals."""
78-
if not self.confidence_intervals:
79-
return None
80-
return max(self.confidence_intervals)
81-
82-
def summary(self):
83-
"""Summarise the causal test result as an intuitive sentence."""
84-
print(
85-
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
86-
f"{self.treatment[0]}' = {self.treatment_value} is {self.ate} (95% confidence intervals: "
87-
f"{self.confidence_intervals})."
88-
)
89-
905

916
class CausalTestOutcome(ABC):
927
"""An abstract class representing an expected causal effect."""
@@ -110,7 +25,7 @@ def __init__(self, value: float, tolerance: float = None):
11025
self.tolerance = tolerance
11126

11227
def apply(self, res: CausalTestResult) -> bool:
113-
return np.isclose(res.ate, self.value, atol=self.tolerance)
28+
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
11429

11530
def __str__(self):
11631
return f"ExactValue: {self.value}±{self.tolerance}"
@@ -121,22 +36,31 @@ class Positive(CausalTestOutcome):
12136

12237
def apply(self, res: CausalTestResult) -> bool:
12338
# TODO: confidence intervals?
124-
return res.ate > 0
39+
if res.test_value.type == "ate":
40+
return res.test_value.value > 0
41+
elif res.test_value.type == "risk_ratio":
42+
return res.test_value.value > 1
12543

12644

12745
class Negative(CausalTestOutcome):
12846
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
12947

13048
def apply(self, res: CausalTestResult) -> bool:
13149
# TODO: confidence intervals?
132-
return res.ate < 0
50+
if res.test_value.type == "ate":
51+
return res.test_value.value < 0
52+
elif res.test_value.type == "risk_ratio":
53+
return res.test_value.value < 1
13354

13455

13556
class SomeEffect(CausalTestOutcome):
13657
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
13758

13859
def apply(self, res: CausalTestResult) -> bool:
139-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
60+
if res.test_value.type == "ate":
61+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
62+
elif res.test_value.type == "risk_ratio":
63+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
14064

14165
def __str__(self):
14266
return "Changed"
@@ -146,7 +70,10 @@ class NoEffect(CausalTestOutcome):
14670
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
14771

14872
def apply(self, res: CausalTestResult) -> bool:
149-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
73+
if res.test_value.type == "ate":
74+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
75+
elif res.test_value.type == "risk_ratio":
76+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
15077

15178
def __str__(self):
15279
return "Unchanged"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Any, Union
2+
from dataclasses import dataclass
3+
4+
from causal_testing.specification.variable import Variable
5+
6+
7+
@dataclass
8+
class TestValue:
9+
"""A dataclass to hold both the type and value of a causal test result"""
10+
11+
type: str
12+
value: float
13+
14+
15+
class CausalTestResult:
16+
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
17+
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
18+
confidence intervals."""
19+
20+
def __init__(
21+
self,
22+
treatment: tuple,
23+
outcome: tuple,
24+
treatment_value: Union[int, float, str],
25+
control_value: Union[int, float, str],
26+
adjustment_set: set,
27+
test_value: TestValue,
28+
confidence_intervals: [float, float] = None,
29+
effect_modifier_configuration: {Variable: Any} = None,
30+
):
31+
self.treatment = treatment
32+
self.outcome = outcome
33+
self.treatment_value = treatment_value
34+
self.control_value = control_value
35+
if adjustment_set:
36+
self.adjustment_set = adjustment_set
37+
else:
38+
self.adjustment_set = set()
39+
self.test_value = test_value
40+
self.confidence_intervals = confidence_intervals
41+
42+
if effect_modifier_configuration is not None:
43+
self.effect_modifier_configuration = effect_modifier_configuration
44+
else:
45+
self.effect_modifier_configuration = {}
46+
47+
def __str__(self):
48+
base_str = (
49+
f"Causal Test Result\n==============\n"
50+
f"Treatment: {self.treatment[0]}\n"
51+
f"Control value: {self.control_value}\n"
52+
f"Treatment value: {self.treatment_value}\n"
53+
f"Outcome: {self.outcome[0]}\n"
54+
f"Adjustment set: {self.adjustment_set}\n"
55+
f"{self.test_value.type}: {self.test_value.value}\n"
56+
)
57+
confidence_str = ""
58+
if self.confidence_intervals:
59+
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
60+
return base_str + confidence_str
61+
62+
def to_dict(self):
63+
base_dict = {
64+
"treatment": self.treatment[0],
65+
"control_value": self.control_value,
66+
"treatment_value": self.treatment_value,
67+
"outcome": self.outcome[0],
68+
"adjustment_set": self.adjustment_set,
69+
"test_value": self.test_value,
70+
}
71+
if self.confidence_intervals:
72+
base_dict["ci_low"] = min(self.confidence_intervals)
73+
base_dict["ci_high"] = max(self.confidence_intervals)
74+
return base_dict
75+
76+
def ci_low(self):
77+
"""Return the lower bracket of the confidence intervals."""
78+
if not self.confidence_intervals:
79+
return None
80+
return min(self.confidence_intervals)
81+
82+
def ci_high(self):
83+
"""Return the higher bracket of the confidence intervals."""
84+
if not self.confidence_intervals:
85+
return None
86+
return max(self.confidence_intervals)
87+
88+
def summary(self):
89+
"""Summarise the causal test result as an intuitive sentence."""
90+
print(
91+
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
92+
f"{self.treatment[0]}' = {self.treatment_value} is {self.test_value.value} (95% confidence intervals: "
93+
f"{self.confidence_intervals})."
94+
)

0 commit comments

Comments
 (0)