1
1
from abc import ABC , abstractmethod
2
- from typing import Any , Union
3
-
2
+ from causal_testing .testing .causal_test_result import CausalTestResult
4
3
import numpy as np
5
4
6
- from causal_testing .specification .variable import Variable
7
-
8
-
9
- class CausalTestResult :
10
- """A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11
- the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12
- confidence intervals."""
13
-
14
- def __init__ (
15
- self ,
16
- treatment : tuple ,
17
- outcome : tuple ,
18
- treatment_value : Union [int , float , str ],
19
- control_value : Union [int , float , str ],
20
- adjustment_set : set ,
21
- ate : float ,
22
- confidence_intervals : [float , float ] = None ,
23
- effect_modifier_configuration : {Variable : Any } = None ,
24
- ):
25
- self .treatment = treatment
26
- self .outcome = outcome
27
- self .treatment_value = treatment_value
28
- self .control_value = control_value
29
- if adjustment_set :
30
- self .adjustment_set = adjustment_set
31
- else :
32
- self .adjustment_set = set ()
33
- self .ate = ate
34
- self .confidence_intervals = confidence_intervals
35
-
36
- if effect_modifier_configuration is not None :
37
- self .effect_modifier_configuration = effect_modifier_configuration
38
- else :
39
- self .effect_modifier_configuration = {}
40
-
41
- def __str__ (self ):
42
- base_str = (
43
- f"Causal Test Result\n ==============\n "
44
- f"Treatment: { self .treatment [0 ]} \n "
45
- f"Control value: { self .control_value } \n "
46
- f"Treatment value: { self .treatment_value } \n "
47
- f"Outcome: { self .outcome [0 ]} \n "
48
- f"Adjustment set: { self .adjustment_set } \n "
49
- f"ATE: { self .ate } \n "
50
- )
51
- confidence_str = ""
52
- if self .confidence_intervals :
53
- confidence_str += f"Confidence intervals: { self .confidence_intervals } \n "
54
- return base_str + confidence_str
55
-
56
- def to_dict (self ):
57
- base_dict = {
58
- "treatment" : self .treatment [0 ],
59
- "control_value" : self .control_value ,
60
- "treatment_value" : self .treatment_value ,
61
- "outcome" : self .outcome [0 ],
62
- "adjustment_set" : self .adjustment_set ,
63
- "ate" : self .ate ,
64
- }
65
- if self .confidence_intervals :
66
- base_dict ["ci_low" ] = min (self .confidence_intervals )
67
- base_dict ["ci_high" ] = max (self .confidence_intervals )
68
- return base_dict
69
-
70
- def ci_low (self ):
71
- """Return the lower bracket of the confidence intervals."""
72
- if not self .confidence_intervals :
73
- return None
74
- return min (self .confidence_intervals )
75
-
76
- def ci_high (self ):
77
- """Return the higher bracket of the confidence intervals."""
78
- if not self .confidence_intervals :
79
- return None
80
- return max (self .confidence_intervals )
81
-
82
- def summary (self ):
83
- """Summarise the causal test result as an intuitive sentence."""
84
- print (
85
- f"The causal effect of changing { self .treatment [0 ]} = { self .control_value } to "
86
- f"{ self .treatment [0 ]} ' = { self .treatment_value } is { self .ate } (95% confidence intervals: "
87
- f"{ self .confidence_intervals } )."
88
- )
89
-
90
5
91
6
class CausalTestOutcome (ABC ):
92
7
"""An abstract class representing an expected causal effect."""
@@ -110,7 +25,7 @@ def __init__(self, value: float, tolerance: float = None):
110
25
self .tolerance = tolerance
111
26
112
27
def apply (self , res : CausalTestResult ) -> bool :
113
- return np .isclose (res .ate , self .value , atol = self .tolerance )
28
+ return np .isclose (res .test_value . value , self .value , atol = self .tolerance )
114
29
115
30
def __str__ (self ):
116
31
return f"ExactValue: { self .value } ±{ self .tolerance } "
@@ -121,22 +36,31 @@ class Positive(CausalTestOutcome):
121
36
122
37
def apply (self , res : CausalTestResult ) -> bool :
123
38
# TODO: confidence intervals?
124
- return res .ate > 0
39
+ if res .test_value .type == "ate" :
40
+ return res .test_value .value > 0
41
+ elif res .test_value .type == "risk_ratio" :
42
+ return res .test_value .value > 1
125
43
126
44
127
45
class Negative (CausalTestOutcome ):
128
46
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
129
47
130
48
def apply (self , res : CausalTestResult ) -> bool :
131
49
# TODO: confidence intervals?
132
- return res .ate < 0
50
+ if res .test_value .type == "ate" :
51
+ return res .test_value .value < 0
52
+ elif res .test_value .type == "risk_ratio" :
53
+ return res .test_value .value < 1
133
54
134
55
135
56
class SomeEffect (CausalTestOutcome ):
136
57
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
137
58
138
59
def apply (self , res : CausalTestResult ) -> bool :
139
- return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
60
+ if res .test_value .type == "ate" :
61
+ return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
62
+ elif res .test_value .type == "risk_ratio" :
63
+ return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
140
64
141
65
def __str__ (self ):
142
66
return "Changed"
@@ -146,7 +70,10 @@ class NoEffect(CausalTestOutcome):
146
70
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
147
71
148
72
def apply (self , res : CausalTestResult ) -> bool :
149
- return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .ate ) < 1e-10 )
73
+ if res .test_value .type == "ate" :
74
+ return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .test_value .value ) < 1e-10 )
75
+ elif res .test_value .type == "risk_ratio" :
76
+ return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = 1e-10 )
150
77
151
78
def __str__ (self ):
152
79
return "Unchanged"
0 commit comments