1
1
from abc import ABC , abstractmethod
2
- from typing import Any , Union
3
-
2
+ from causal_test_result import CausalTestResult
4
3
import numpy as np
5
4
6
- from causal_testing .specification .variable import Variable
7
-
8
-
9
- class CausalTestResult :
10
- """A container to hold the results of a causal test case. Every causal test case provides a point estimate of
11
- the ATE, given a particular treatment, outcome, and adjustment set. Some but not all estimators can provide
12
- confidence intervals."""
13
-
14
- def __init__ (
15
- self ,
16
- treatment : tuple ,
17
- outcome : tuple ,
18
- treatment_value : Union [int , float , str ],
19
- control_value : Union [int , float , str ],
20
- adjustment_set : set ,
21
- ate : float ,
22
- confidence_intervals : [float , float ] = None ,
23
- effect_modifier_configuration : {Variable : Any } = None ,
24
- ):
25
- self .treatment = treatment
26
- self .outcome = outcome
27
- self .treatment_value = treatment_value
28
- self .control_value = control_value
29
- if adjustment_set :
30
- self .adjustment_set = adjustment_set
31
- else :
32
- self .adjustment_set = set ()
33
- self .ate = ate
34
- self .confidence_intervals = confidence_intervals
35
-
36
- if effect_modifier_configuration is not None :
37
- self .effect_modifier_configuration = effect_modifier_configuration
38
- else :
39
- self .effect_modifier_configuration = {}
40
-
41
- def __str__ (self ):
42
- base_str = (
43
- f"Causal Test Result\n ==============\n "
44
- f"Treatment: { self .treatment [0 ]} \n "
45
- f"Control value: { self .control_value } \n "
46
- f"Treatment value: { self .treatment_value } \n "
47
- f"Outcome: { self .outcome [0 ]} \n "
48
- f"Adjustment set: { self .adjustment_set } \n "
49
- f"ATE: { self .ate } \n "
50
- )
51
- confidence_str = ""
52
- if self .confidence_intervals :
53
- confidence_str += f"Confidence intervals: { self .confidence_intervals } \n "
54
- return base_str + confidence_str
55
-
56
- def to_dict (self ):
57
- base_dict = {
58
- "treatment" : self .treatment [0 ],
59
- "control_value" : self .control_value ,
60
- "treatment_value" : self .treatment_value ,
61
- "outcome" : self .outcome [0 ],
62
- "adjustment_set" : self .adjustment_set ,
63
- "ate" : self .ate ,
64
- }
65
- if self .confidence_intervals :
66
- base_dict ["ci_low" ] = min (self .confidence_intervals )
67
- base_dict ["ci_high" ] = max (self .confidence_intervals )
68
- return base_dict
69
-
70
- def ci_low (self ):
71
- """Return the lower bracket of the confidence intervals."""
72
- if not self .confidence_intervals :
73
- return None
74
- return min (self .confidence_intervals )
75
-
76
- def ci_high (self ):
77
- """Return the higher bracket of the confidence intervals."""
78
- if not self .confidence_intervals :
79
- return None
80
- return max (self .confidence_intervals )
81
-
82
- def summary (self ):
83
- """Summarise the causal test result as an intuitive sentence."""
84
- print (
85
- f"The causal effect of changing { self .treatment [0 ]} = { self .control_value } to "
86
- f"{ self .treatment [0 ]} ' = { self .treatment_value } is { self .ate } (95% confidence intervals: "
87
- f"{ self .confidence_intervals } )."
88
- )
89
-
90
-
91
5
class CausalTestOutcome (ABC ):
92
6
"""An abstract class representing an expected causal effect."""
93
7
@@ -110,7 +24,7 @@ def __init__(self, value: float, tolerance: float = None):
110
24
self .tolerance = tolerance
111
25
112
26
def apply (self , res : CausalTestResult ) -> bool :
113
- return np .isclose (res .ate , self .value , atol = self .tolerance )
27
+ return np .isclose (res .test_value . value , self .value , atol = self .tolerance )
114
28
115
29
def __str__ (self ):
116
30
return f"ExactValue: { self .value } ±{ self .tolerance } "
@@ -121,22 +35,31 @@ class Positive(CausalTestOutcome):
121
35
122
36
def apply (self , res : CausalTestResult ) -> bool :
123
37
# TODO: confidence intervals?
124
- return res .ate > 0
38
+ if res .test_value .type == "ate" :
39
+ return res .test_value .value > 0
40
+ elif res .test_value .type == "risk_ratio" :
41
+ return res .test_value .value > 1
125
42
126
43
127
44
class Negative (CausalTestOutcome ):
128
45
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
129
46
130
47
def apply (self , res : CausalTestResult ) -> bool :
131
48
# TODO: confidence intervals?
132
- return res .ate < 0
49
+ if res .test_value .type == "ate" :
50
+ return res .test_value .value < 0
51
+ elif res .test_value .type == "risk_ratio" :
52
+ return res .test_value .value < 1
133
53
134
54
135
55
class SomeEffect (CausalTestOutcome ):
136
56
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
137
57
138
58
def apply (self , res : CausalTestResult ) -> bool :
139
- return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
59
+ if res .test_value .type == "ate" :
60
+ return (0 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 0 )
61
+ elif res .test_value .type == "risk_ratio" :
62
+ return (1 < res .ci_low () < res .ci_high ()) or (res .ci_low () < res .ci_high () < 1 )
140
63
141
64
def __str__ (self ):
142
65
return "Changed"
@@ -146,7 +69,10 @@ class NoEffect(CausalTestOutcome):
146
69
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
147
70
148
71
def apply (self , res : CausalTestResult ) -> bool :
149
- return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .ate ) < 1e-10 )
72
+ if res .test_value .type == "ate" :
73
+ return (res .ci_low () < 0 < res .ci_high ()) or (abs (res .ate ) < 1e-10 )
74
+ elif res .test_value .type == "risk_ratio" :
75
+ return (res .ci_low () < 1 < res .ci_high ()) or np .isclose (res .test_value .value , 1.0 , atol = 1e-10 )
150
76
151
77
def __str__ (self ):
152
78
return "Unchanged"
0 commit comments