Skip to content

Commit 81a635c

Browse files
committed
Merge branch 'main' of github.com:CITCOM-project/CausalTestingFramework into json-scenario
2 parents 7494aed + 8d77dea commit 81a635c

File tree

20 files changed

+598
-468
lines changed

20 files changed

+598
-468
lines changed

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
176176

177177
result_string = str()
178178
if causal_test_result.ci_low() and causal_test_result.ci_high():
179-
result_string = (
180-
f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
181-
)
179+
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
182180
else:
183-
result_string = causal_test_result.ate
181+
result_string = causal_test_result.test_value.value
184182
if f_flag:
185183
assert test_passes, (
186184
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
@@ -191,7 +189,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
191189
logger.warning(
192190
" FAILED- expected %s, got %s",
193191
causal_test_case.expected_causal_effect,
194-
causal_test_result.ate,
192+
causal_test_result.test_value.value,
195193
)
196194
return failed
197195

causal_testing/testing/causal_test_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from causal_testing.data_collection.data_collector import DataCollector
66
from causal_testing.specification.causal_specification import CausalSpecification
77
from causal_testing.testing.causal_test_case import CausalTestCase
8-
from causal_testing.testing.causal_test_outcome import CausalTestResult
8+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
99
from causal_testing.testing.estimators import Estimator
1010

1111
logger = logging.getLogger(__name__)
@@ -138,7 +138,7 @@ def execute_test(
138138
treatment_value=estimator.treatment_values,
139139
control_value=estimator.control_values,
140140
adjustment_set=estimator.adjustment_set,
141-
ate=cates_df,
141+
test_value=TestValue("ate", cates_df),
142142
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
143143
confidence_intervals=confidence_intervals,
144144
)
@@ -151,7 +151,7 @@ def execute_test(
151151
treatment_value=estimator.treatment_values,
152152
control_value=estimator.control_values,
153153
adjustment_set=estimator.adjustment_set,
154-
ate=risk_ratio,
154+
test_value=TestValue("risk_ratio", risk_ratio),
155155
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
156156
confidence_intervals=confidence_intervals,
157157
)
@@ -164,7 +164,7 @@ def execute_test(
164164
treatment_value=estimator.treatment_values,
165165
control_value=estimator.control_values,
166166
adjustment_set=estimator.adjustment_set,
167-
ate=ate,
167+
test_value=TestValue("ate", ate),
168168
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
169169
confidence_intervals=confidence_intervals,
170170
)
@@ -179,7 +179,7 @@ def execute_test(
179179
treatment_value=estimator.treatment_values,
180180
control_value=estimator.control_values,
181181
adjustment_set=estimator.adjustment_set,
182-
ate=ate,
182+
test_value=TestValue("ate", ate),
183183
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184184
confidence_intervals=confidence_intervals,
185185
)
Lines changed: 18 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,7 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Union
3-
2+
from causal_testing.testing.causal_test_result import CausalTestResult
43
import numpy as np
54

6-
from causal_testing.specification.variable import Variable
7-
8-
9-
class CausalTestResult:
10-
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11-
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12-
confidence intervals."""
13-
14-
def __init__(
15-
self,
16-
treatment: tuple,
17-
outcome: tuple,
18-
treatment_value: Union[int, float, str],
19-
control_value: Union[int, float, str],
20-
adjustment_set: set,
21-
ate: float,
22-
confidence_intervals: [float, float] = None,
23-
effect_modifier_configuration: {Variable: Any} = None,
24-
):
25-
self.treatment = treatment
26-
self.outcome = outcome
27-
self.treatment_value = treatment_value
28-
self.control_value = control_value
29-
if adjustment_set:
30-
self.adjustment_set = adjustment_set
31-
else:
32-
self.adjustment_set = set()
33-
self.ate = ate
34-
self.confidence_intervals = confidence_intervals
35-
36-
if effect_modifier_configuration is not None:
37-
self.effect_modifier_configuration = effect_modifier_configuration
38-
else:
39-
self.effect_modifier_configuration = {}
40-
41-
def __str__(self):
42-
base_str = (
43-
f"Causal Test Result\n==============\n"
44-
f"Treatment: {self.treatment[0]}\n"
45-
f"Control value: {self.control_value}\n"
46-
f"Treatment value: {self.treatment_value}\n"
47-
f"Outcome: {self.outcome[0]}\n"
48-
f"Adjustment set: {self.adjustment_set}\n"
49-
f"ATE: {self.ate}\n"
50-
)
51-
confidence_str = ""
52-
if self.confidence_intervals:
53-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
54-
return base_str + confidence_str
55-
56-
def to_dict(self):
57-
base_dict = {
58-
"treatment": self.treatment[0],
59-
"control_value": self.control_value,
60-
"treatment_value": self.treatment_value,
61-
"outcome": self.outcome[0],
62-
"adjustment_set": self.adjustment_set,
63-
"ate": self.ate,
64-
}
65-
if self.confidence_intervals:
66-
base_dict["ci_low"] = min(self.confidence_intervals)
67-
base_dict["ci_high"] = max(self.confidence_intervals)
68-
return base_dict
69-
70-
def ci_low(self):
71-
"""Return the lower bracket of the confidence intervals."""
72-
if not self.confidence_intervals:
73-
return None
74-
return min(self.confidence_intervals)
75-
76-
def ci_high(self):
77-
"""Return the higher bracket of the confidence intervals."""
78-
if not self.confidence_intervals:
79-
return None
80-
return max(self.confidence_intervals)
81-
82-
def summary(self):
83-
"""Summarise the causal test result as an intuitive sentence."""
84-
print(
85-
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
86-
f"{self.treatment[0]}' = {self.treatment_value} is {self.ate} (95% confidence intervals: "
87-
f"{self.confidence_intervals})."
88-
)
89-
905

916
class CausalTestOutcome(ABC):
927
"""An abstract class representing an expected causal effect."""
@@ -110,7 +25,7 @@ def __init__(self, value: float, tolerance: float = None):
11025
self.tolerance = tolerance
11126

11227
def apply(self, res: CausalTestResult) -> bool:
113-
return np.isclose(res.ate, self.value, atol=self.tolerance)
28+
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
11429

11530
def __str__(self):
11631
return f"ExactValue: {self.value}±{self.tolerance}"
@@ -121,22 +36,31 @@ class Positive(CausalTestOutcome):
12136

12237
def apply(self, res: CausalTestResult) -> bool:
12338
# TODO: confidence intervals?
124-
return res.ate > 0
39+
if res.test_value.type == "ate":
40+
return res.test_value.value > 0
41+
elif res.test_value.type == "risk_ratio":
42+
return res.test_value.value > 1
12543

12644

12745
class Negative(CausalTestOutcome):
12846
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
12947

13048
def apply(self, res: CausalTestResult) -> bool:
13149
# TODO: confidence intervals?
132-
return res.ate < 0
50+
if res.test_value.type == "ate":
51+
return res.test_value.value < 0
52+
elif res.test_value.type == "risk_ratio":
53+
return res.test_value.value < 1
13354

13455

13556
class SomeEffect(CausalTestOutcome):
13657
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
13758

13859
def apply(self, res: CausalTestResult) -> bool:
139-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
60+
if res.test_value.type == "ate":
61+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
62+
elif res.test_value.type == "risk_ratio":
63+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
14064

14165
def __str__(self):
14266
return "Changed"
@@ -146,7 +70,10 @@ class NoEffect(CausalTestOutcome):
14670
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
14771

14872
def apply(self, res: CausalTestResult) -> bool:
149-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.ate) < 1e-10)
73+
if res.test_value.type == "ate":
74+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
75+
elif res.test_value.type == "risk_ratio":
76+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
15077

15178
def __str__(self):
15279
return "Unchanged"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Any, Union
2+
from dataclasses import dataclass
3+
4+
from causal_testing.specification.variable import Variable
5+
6+
7+
@dataclass
8+
class TestValue:
9+
"""A dataclass to hold both the type and value of a causal test result"""
10+
11+
type: str
12+
value: float
13+
14+
15+
class CausalTestResult:
16+
"""A container to hold the results of a causal test case. Every causal test case provides a point estimate of
17+
the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
18+
confidence intervals."""
19+
20+
def __init__(
21+
self,
22+
treatment: tuple,
23+
outcome: tuple,
24+
treatment_value: Union[int, float, str],
25+
control_value: Union[int, float, str],
26+
adjustment_set: set,
27+
test_value: TestValue,
28+
confidence_intervals: [float, float] = None,
29+
effect_modifier_configuration: {Variable: Any} = None,
30+
):
31+
self.treatment = treatment
32+
self.outcome = outcome
33+
self.treatment_value = treatment_value
34+
self.control_value = control_value
35+
if adjustment_set:
36+
self.adjustment_set = adjustment_set
37+
else:
38+
self.adjustment_set = set()
39+
self.test_value = test_value
40+
self.confidence_intervals = confidence_intervals
41+
42+
if effect_modifier_configuration is not None:
43+
self.effect_modifier_configuration = effect_modifier_configuration
44+
else:
45+
self.effect_modifier_configuration = {}
46+
47+
def __str__(self):
48+
base_str = (
49+
f"Causal Test Result\n==============\n"
50+
f"Treatment: {self.treatment[0]}\n"
51+
f"Control value: {self.control_value}\n"
52+
f"Treatment value: {self.treatment_value}\n"
53+
f"Outcome: {self.outcome[0]}\n"
54+
f"Adjustment set: {self.adjustment_set}\n"
55+
f"{self.test_value.type}: {self.test_value.value}\n"
56+
)
57+
confidence_str = ""
58+
if self.confidence_intervals:
59+
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
60+
return base_str + confidence_str
61+
62+
def to_dict(self):
63+
base_dict = {
64+
"treatment": self.treatment[0],
65+
"control_value": self.control_value,
66+
"treatment_value": self.treatment_value,
67+
"outcome": self.outcome[0],
68+
"adjustment_set": self.adjustment_set,
69+
"test_value": self.test_value,
70+
}
71+
if self.confidence_intervals:
72+
base_dict["ci_low"] = min(self.confidence_intervals)
73+
base_dict["ci_high"] = max(self.confidence_intervals)
74+
return base_dict
75+
76+
def ci_low(self):
77+
"""Return the lower bracket of the confidence intervals."""
78+
if not self.confidence_intervals:
79+
return None
80+
return min(self.confidence_intervals)
81+
82+
def ci_high(self):
83+
"""Return the higher bracket of the confidence intervals."""
84+
if not self.confidence_intervals:
85+
return None
86+
return max(self.confidence_intervals)
87+
88+
def summary(self):
89+
"""Summarise the causal test result as an intuitive sentence."""
90+
print(
91+
f"The causal effect of changing {self.treatment[0]} = {self.control_value} to "
92+
f"{self.treatment[0]}' = {self.treatment_value} is {self.test_value.value} (95% confidence intervals: "
93+
f"{self.confidence_intervals})."
94+
)

0 commit comments

Comments
 (0)