Skip to content

Commit 5621236

Browse files
committed
causal_test_result + new dataclass for test values
1 parent 6e9a369 commit 5621236

File tree

3 files changed

+115
-97
lines changed

3 files changed

+115
-97
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from causal_testing.data_collection.data_collector import DataCollector
66
from causal_testing.specification.causal_specification import CausalSpecification
77
from causal_testing.testing.causal_test_case import CausalTestCase
8-
from causal_testing.testing.causal_test_outcome import CausalTestResult
8+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
99
from causal_testing.testing.estimators import Estimator
1010

1111
logger = logging.getLogger(__name__)
@@ -138,7 +138,7 @@ def execute_test(
138138
treatment_value=estimator.treatment_values,
139139
control_value=estimator.control_values,
140140
adjustment_set=estimator.adjustment_set,
141-
ate=cates_df,
141+
test_value=TestValue("ate", cates_df),
142142
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
143143
confidence_intervals=confidence_intervals,
144144
)
@@ -151,7 +151,7 @@ def execute_test(
151151
treatment_value=estimator.treatment_values,
152152
control_value=estimator.control_values,
153153
adjustment_set=estimator.adjustment_set,
154-
ate=risk_ratio,
154+
test_value=TestValue("risk_ratio", risk_ratio),
155155
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
156156
confidence_intervals=confidence_intervals,
157157
)
@@ -164,7 +164,7 @@ def execute_test(
164164
treatment_value=estimator.treatment_values,
165165
control_value=estimator.control_values,
166166
adjustment_set=estimator.adjustment_set,
167-
ate=ate,
167+
test_value=TestValue("ate", ate),
168168
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
169169
confidence_intervals=confidence_intervals,
170170
)
@@ -179,7 +179,7 @@ def execute_test(
179179
treatment_value=estimator.treatment_values,
180180
control_value=estimator.control_values,
181181
adjustment_set=estimator.adjustment_set,
182-
ate=ate,
182+
test_value=TestValue("ate", ate),
183183
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184184
confidence_intervals=confidence_intervals,
185185
)
Lines changed: 18 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,7 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Union
3-
2+
from causal_test_result import CausalTestResult
43
import numpy as np
54

6-
from causal_testing.specification.variable import Variable
7-
8-
9-
class CausalTestResult:
10-
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11-
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12-
confidence intervals."""
13-
14-
def __init__(
15-
self,
16-
treatment: tuple,
17-
outcome: tuple,
18-
treatment_value: Union[int, float, str],
19-
control_value: Union[int, float, str],
20-
adjustment_set: set,
21-
ate: float,
22-
confidence_intervals: [float, float] = None,
23-
effect_modifier_configuration: {Variable: Any} = None,
24-
):
25-
self.treatment = treatment
26-
self.outcome = outcome
27-
self.treatment_value = treatment_value
28-
self.control_value = control_value
29-
if adjustment_set:
30-
self.adjustment_set = adjustment_set
31-
else:
32-
self.adjustment_set = set()
33-
self.ate = ate
34-
self.confidence_intervals = confidence_intervals
35-
36-
if effect_modifier_configuration is not None:
37-
self.effect_modifier_configuration = effect_modifier_configuration
38-
else:
39-
self.effect_modifier_configuration = {}
40-
41-
def __str__(self):
42-
base_str = (
43-
f"Causal Test Result\n==============\n"
44-
f"Treatment: {self.treatment[0]}\n"
45-
f"Control value: {self.control_value}\n"
46-
f"Treatment value: {self.treatment_value}\n"
47-
f"Outcome: {self.outcome[0]}\n"
48-
f"Adjustment set: {self.adjustment_set}\n"
49-
f"ATE: {self.ate}\n"
50-
)
51-
confidence_str = ""
52-
if self.confidence_intervals:
53-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
54-
return base_str + confidence_str
55-
56-
def to_dict(self):
57-
base_dict = {
58-
"treatment": self.treatment[0],
59-
"control_value": self.control_value,
60-
"treatment_value": self.treatment_value,
61-
"outcome": self.outcome[0],
62-
"adjustment_set": self.adjustment_set,
63-
"ate": self.ate,
64-
}
65-
if self.confidence_intervals:
66-
base_dict["ci_low"] = min(self.confidence_intervals)
67-
base_dict["ci_high"] = max(self.confidence_intervals)
68-
return base_dict
69-
70-
def ci_low(self):
71-
"""Return the lower bracket of the confidence intervals."""
72-
if not self.confidence_intervals:
73-
return None
74-
return min(self.confidence_intervals)
75-
76-
def ci_high(self):
77-
"""Return the higher bracket of the confidence intervals."""
78-
if not self.confidence_intervals:
79-
return None
80-
return max(self.confidence_intervals)
81-
82-
def summary(self):
83-
"""Summarise the causal test result as an intuitive sentence."""
84-
print(
85-
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
86-
f"{self.treatment[0]}' = {self.treatment_value} is {self.ate} (95% confidence intervals: "
87-
f"{self.confidence_intervals})."
88-
)
89-
90-
915
class CausalTestOutcome(ABC):
926
"""An abstract class representing an expected causal effect."""
937

@@ -110,7 +24,7 @@ def __init__(self, value: float, tolerance: float = None):
11024
self.tolerance = tolerance
11125

11226
def apply(self, res: CausalTestResult) -> bool:
113-
return np.isclose(res.ate, self.value, atol=self.tolerance)
27+
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
11428

11529
def __str__(self):
11630
return f"ExactValue: {self.value}±{self.tolerance}"
@@ -121,22 +35,31 @@ class Positive(CausalTestOutcome):
12135

12236
def apply(self, res: CausalTestResult) -> bool:
12337
# TODO: confidence intervals?
124-
return res.ate > 0
38+
if res.test_value.type == "ate":
39+
return res.test_value.value > 0
40+
elif res.test_value.type == "risk_ratio":
41+
return res.test_value.value > 1
12542

12643

12744
class Negative(CausalTestOutcome):
12845
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
12946

13047
def apply(self, res: CausalTestResult) -> bool:
13148
# TODO: confidence intervals?
132-
return res.ate < 0
49+
if res.test_value.type == "ate":
50+
return res.test_value.value < 0
51+
elif res.test_value.type == "risk_ratio":
52+
return res.test_value.value < 1
13353

13454

13555
class SomeEffect(CausalTestOutcome):
13656
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
13757

13858
def apply(self, res: CausalTestResult) -> bool:
139-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
59+
if res.test_value.type == "ate":
60+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
61+
elif res.test_value.type == "risk_ratio":
62+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
14063

14164
def __str__(self):
14265
return "Changed"
@@ -146,7 +69,10 @@ class NoEffect(CausalTestOutcome):
14669
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
14770

14871
def apply(self, res: CausalTestResult) -> bool:
149-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
72+
if res.test_value.type == "ate":
73+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
74+
elif res.test_value.type == "risk_ratio":
75+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
15076

15177
def __str__(self):
15278
return "Unchanged"
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import Any, Union
2+
from dataclasses import dataclass
3+
4+
from causal_testing.specification.variable import Variable
5+
6+
@dataclass
7+
class TestValue:
8+
"""A dataclass to hold both the type and value of a causal test result"""
9+
10+
type: str
11+
value: float
12+
13+
class CausalTestResult:
14+
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
15+
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
16+
confidence intervals."""
17+
18+
def __init__(
19+
self,
20+
treatment: tuple,
21+
outcome: tuple,
22+
treatment_value: Union[int, float, str],
23+
control_value: Union[int, float, str],
24+
adjustment_set: set,
25+
test_value: TestValue,
26+
confidence_intervals: [float, float] = None,
27+
effect_modifier_configuration: {Variable: Any} = None,
28+
):
29+
self.treatment = treatment
30+
self.outcome = outcome
31+
self.treatment_value = treatment_value
32+
self.control_value = control_value
33+
if adjustment_set:
34+
self.adjustment_set = adjustment_set
35+
else:
36+
self.adjustment_set = set()
37+
self.test_value = test_value
38+
self.confidence_intervals = confidence_intervals
39+
40+
if effect_modifier_configuration is not None:
41+
self.effect_modifier_configuration = effect_modifier_configuration
42+
else:
43+
self.effect_modifier_configuration = {}
44+
45+
def __str__(self):
46+
base_str = (
47+
f"Causal Test Result\n==============\n"
48+
f"Treatment: {self.treatment[0]}\n"
49+
f"Control value: {self.control_value}\n"
50+
f"Treatment value: {self.treatment_value}\n"
51+
f"Outcome: {self.outcome[0]}\n"
52+
f"Adjustment set: {self.adjustment_set}\n"
53+
f"{self.test_value.type}: {self.test_value.value}\n"
54+
)
55+
confidence_str = ""
56+
if self.confidence_intervals:
57+
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
58+
return base_str + confidence_str
59+
60+
def to_dict(self):
61+
base_dict = {
62+
"treatment": self.treatment[0],
63+
"control_value": self.control_value,
64+
"treatment_value": self.treatment_value,
65+
"outcome": self.outcome[0],
66+
"adjustment_set": self.adjustment_set,
67+
"test_value": self.test_value,
68+
}
69+
if self.confidence_intervals:
70+
base_dict["ci_low"] = min(self.confidence_intervals)
71+
base_dict["ci_high"] = max(self.confidence_intervals)
72+
return base_dict
73+
74+
def ci_low(self):
75+
"""Return the lower bracket of the confidence intervals."""
76+
if not self.confidence_intervals:
77+
return None
78+
return min(self.confidence_intervals)
79+
80+
def ci_high(self):
81+
"""Return the higher bracket of the confidence intervals."""
82+
if not self.confidence_intervals:
83+
return None
84+
return max(self.confidence_intervals)
85+
86+
def summary(self):
87+
"""Summarise the causal test result as an intuitive sentence."""
88+
print(
89+
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
90+
f"{self.treatment[0]}' = {self.treatment_value} is {self.test_value.value} (95% confidence intervals: "
91+
f"{self.confidence_intervals})."
92+
)

0 commit comments

Comments
 (0)