9
9
from causal_testing .testing .base_test_case import BaseTestCase
10
10
from causal_testing .testing .estimators import Estimator
11
11
from causal_testing .testing .causal_test_result import CausalTestResult
12
- from causal_testing .data_collection .data_collector import DataCollector
12
+ from causal_testing .data_collection .data_collector import ObservationalDataCollector
13
+ from causal_testing .specification .causal_dag import CausalDAG
14
+ from causal_testing .specification .scenario import Scenario
13
15
16
+ from causal_testing .specification .causal_specification import CausalSpecification
14
17
logger = logging .getLogger (__name__ )
15
18
16
19
@@ -28,15 +31,15 @@ class CausalTestCase:
28
31
"""
29
32
30
33
def __init__ (
31
- # pylint: disable=too-many-arguments
32
- self ,
33
- base_test_case : BaseTestCase ,
34
- expected_causal_effect : CausalTestOutcome ,
35
- control_value : Any = None ,
36
- treatment_value : Any = None ,
37
- estimate_type : str = "ate" ,
38
- estimate_params : dict = None ,
39
- effect_modifier_configuration : dict [Variable :Any ] = None ,
34
+ # pylint: disable=too-many-arguments
35
+ self ,
36
+ base_test_case : BaseTestCase ,
37
+ expected_causal_effect : CausalTestOutcome ,
38
+ control_value : Any = None ,
39
+ treatment_value : Any = None ,
40
+ estimate_type : str = "ate" ,
41
+ estimate_params : dict = None ,
42
+ effect_modifier_configuration : dict [Variable :Any ] = None ,
40
43
):
41
44
"""
42
45
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
@@ -78,35 +81,81 @@ def get_treatment_value(self):
78
81
"""Return the treatment value of the treatment variable in this causal test case."""
79
82
return self .treatment_value
80
83
81
- def execute_test (self , estimator : type (Estimator ), dataframe : pd . DataFrame ) -> CausalTestResult :
84
+ def execute_test (self , estimator : type (Estimator ), data_collector : ObservationalDataCollector , causal_specification : CausalSpecification ) -> CausalTestResult :
82
85
"""Execute a causal test case and return the causal test result.
83
86
84
87
:param estimator: A reference to an Estimator class.
85
88
:param causal_test_case: The CausalTestCase object to be tested
86
89
:return causal_test_result: A CausalTestResult for the executed causal test case.
87
90
"""
88
- if self . scenario_execution_data_df . empty :
89
- raise ValueError ( "No data has been loaded. Please call load_data prior to executing a causal test case." )
91
+ if not data_collector . data_checked :
92
+ data_collector . collect_data ( )
90
93
if estimator .df is None :
91
- estimator .df = dataframe
94
+ estimator .df = data_collector . data
92
95
treatment_variable = self .treatment_variable
93
96
treatments = treatment_variable .name
94
97
outcome_variable = self .outcome_variable
95
98
96
99
logger .info ("treatments: %s" , treatments )
97
100
logger .info ("outcomes: %s" , outcome_variable )
98
- minimal_adjustment_set = self .causal_dag .identification (BaseTestCase (treatment_variable , outcome_variable ))
101
+ minimal_adjustment_set = causal_specification .causal_dag .identification (BaseTestCase (treatment_variable , outcome_variable ))
99
102
minimal_adjustment_set = minimal_adjustment_set - set (treatment_variable .name )
100
103
minimal_adjustment_set = minimal_adjustment_set - set (outcome_variable .name )
101
104
102
105
variables_for_positivity = list (minimal_adjustment_set ) + [treatment_variable .name ] + [outcome_variable .name ]
103
106
104
- if self ._check_positivity_violation (variables_for_positivity ):
107
+ if self ._check_positivity_violation (variables_for_positivity , causal_specification . scenario , data_collector . data ):
105
108
raise ValueError ("POSITIVITY VIOLATION -- Cannot proceed." )
106
109
107
110
causal_test_result = self ._return_causal_test_results (estimator )
108
111
return causal_test_result
109
112
113
+ def _return_causal_test_results (self , estimator , causal_test_case ):
114
+ """Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
115
+
116
+ :param estimator: An Estimator class object
117
+ :param causal_test_case: The concrete test case to be executed
118
+ :return: a CausalTestResult object containing the confidence intervals
119
+ """
120
+ if not hasattr (estimator , f"estimate_{ causal_test_case .estimate_type } " ):
121
+ raise AttributeError (f"{ estimator .__class__ } has no { causal_test_case .estimate_type } method." )
122
+ estimate_effect = getattr (estimator , f"estimate_{ causal_test_case .estimate_type } " )
123
+ effect , confidence_intervals = estimate_effect (** causal_test_case .estimate_params )
124
+ causal_test_result = CausalTestResult (
125
+ estimator = estimator ,
126
+ test_value = TestValue (causal_test_case .estimate_type , effect ),
127
+ effect_modifier_configuration = causal_test_case .effect_modifier_configuration ,
128
+ confidence_intervals = confidence_intervals ,
129
+ )
130
+
131
+ return causal_test_result
132
+
133
+ def _check_positivity_violation (self , variables_list , scenario : Scenario , df ):
134
+ """Check whether the dataframe has a positivity violation relative to the specified variables list.
135
+
136
+ A positivity violation occurs when there is a stratum of the dataframe which does not have any data. Put simply,
137
+ if we split the dataframe into covariate sub-groups, each sub-group must contain both a treated and untreated
138
+ individual. If a positivity violation occurs, causal inference is still possible using a properly specified
139
+ parametric estimator. Therefore, we should not throw an exception upon violation but raise a warning instead.
140
+
141
+ :param variables_list: The list of variables for which positivity must be satisfied.
142
+ :return: True if positivity is violated, False otherwise.
143
+ """
144
+ if not (set (variables_list ) - {x .name for x in scenario .hidden_variables ()}).issubset (
145
+ df .columns
146
+ ):
147
+ missing_variables = set (variables_list ) - set (df .columns )
148
+ logger .warning (
149
+ "Positivity violation: missing data for variables %s.\n "
150
+ "Causal inference is only valid if a well-specified parametric model is used.\n "
151
+ "Alternatively, consider restricting analysis to executions without the variables:"
152
+ "." ,
153
+ missing_variables ,
154
+ )
155
+ return True
156
+
157
+ return False
158
+
110
159
def __str__ (self ):
111
160
treatment_config = {self .treatment_variable .name : self .treatment_value }
112
161
control_config = {self .treatment_variable .name : self .control_value }
0 commit comments