7
7
from causal_testing .testing .causal_test_case import CausalTestCase
8
8
from causal_testing .testing .causal_test_outcome import CausalTestResult
9
9
from causal_testing .testing .estimators import Estimator
10
+ from causal_testing .testing .base_causal_test import BaseCausalTest
10
11
11
12
logger = logging .getLogger (__name__ )
12
13
@@ -50,10 +51,10 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
50
51
)
51
52
self .data_collector = data_collector
52
53
self .scenario_execution_data_df = self .data_collector .collect_data (** kwargs )
53
- self . minimal_adjustment_set = set ()
54
+
54
55
55
56
def execute_test_suite (
56
- self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
57
+ self , test_suite : dict [dict [CausalTestCase ], dict [Estimator ], str ]
57
58
) -> list [CausalTestResult ]:
58
59
"""Execute a suite of causal tests and return the results in a list"""
59
60
if self .scenario_execution_data_df .empty :
@@ -63,29 +64,37 @@ def execute_test_suite(
63
64
64
65
logger .info ("treatment: %s" , edge .treatment_variable )
65
66
logger .info ("outcome: %s" , edge .outcome_variable )
66
- minimal_adjustment_set = self .casual_dag .identification (edge )
67
- minimal_adjustment_set = minimal_adjustment_set - {v .name for v in edge .treatment_variable }
68
- minimal_adjustment_set = minimal_adjustment_set - {v .name for v in edge .outcome_variable }
69
-
70
- variables_for_positivity = list (minimal_adjustment_set ) + edge .treatment_variable + edge .outcome_variable
67
+ minimal_adjustment_set = self .causal_dag .identification (edge )
68
+ minimal_adjustment_set = minimal_adjustment_set - set (edge .treatment_variable .name )
69
+ minimal_adjustment_set = minimal_adjustment_set - set (edge .outcome_variable .name )
71
70
71
+ variables_for_positivity = list (minimal_adjustment_set ) + [edge .treatment_variable .name ] + [
72
+ edge .outcome_variable .name ]
72
73
if self ._check_positivity_violation (variables_for_positivity ):
73
74
# TODO: We should allow users to continue because positivity can be overcome with parametric models
74
75
# TODO: When we implement causal contracts, we should also note the positivity violation there
75
76
raise Exception ("POSITIVITY VIOLATION -- Cannot proceed." )
76
77
77
- for estimator in edge ["estimators" ]:
78
- if estimator .df is None :
79
- estimator .df = self .scenario_execution_data_df
80
-
81
- for test in edge ["tests" ]:
82
- logger .info ("minimal_adjustment_set: %s" , self .minimal_adjustment_set )
83
- causal_test_result = self ._return_causal_test_results (test_suite .estimate_type , estimator )
78
+ estimators = test_suite [edge ]["estimators" ]
79
+ tests = test_suite [edge ]["tests" ]
80
+ estimate_type = test_suite [edge ]["estimate_type" ]
81
+
82
+ for EstimatorClass in estimators :
83
+
84
+ for test in tests :
85
+ treatment_variable = list (test .treatment_input_configuration .keys ())[0 ]
86
+ treatment_value = list (test .treatment_input_configuration .values ())[0 ]
87
+ control_value = list (test .control_input_configuration .values ())[0 ]
88
+ estimator = EstimatorClass ((treatment_variable .name ,), treatment_value , control_value ,
89
+ minimal_adjustment_set , (test .outcome_variable .name ,))
90
+ if estimator .df is None :
91
+ estimator .df = self .scenario_execution_data_df
92
+ causal_test_result = self ._return_causal_test_results (estimate_type , estimator , test )
84
93
causal_test_results .append (causal_test_result )
85
94
return causal_test_results
86
95
87
96
def execute_test (
88
- self , estimator : Estimator , causal_test_case : CausalTestCase , estimate_type : str = "ate"
97
+ self , estimator : type ( Estimator ) , causal_test_case : CausalTestCase , estimate_type : str = "ate"
89
98
) -> CausalTestResult :
90
99
"""Execute a causal test case and return the causal test result.
91
100
@@ -108,26 +117,26 @@ def execute_test(
108
117
raise Exception ("No data has been loaded. Please call load_data prior to executing a causal test case." )
109
118
if estimator .df is None :
110
119
estimator .df = self .scenario_execution_data_df
111
- treatment_variables = list (causal_test_case .control_input_configuration )
112
- treatments = [ v .name for v in treatment_variables ]
113
- outcomes = [ v . name for v in causal_test_case .outcome_variables ]
120
+ treatment_variable = list (causal_test_case .control_input_configuration . keys ())[ 0 ]
121
+ treatments = treatment_variable .name
122
+ outcome_variable = causal_test_case .outcome_variable
114
123
115
124
logger .info ("treatments: %s" , treatments )
116
- logger .info ("outcomes: %s" , outcomes )
117
- logger .info ("minimal_adjustment_set: %s" , self .minimal_adjustment_set )
118
-
119
- minimal_adjustment_set = self .minimal_adjustment_set - {
125
+ logger .info ("outcomes: %s" , outcome_variable )
126
+ minimal_adjustment_set = self .causal_dag .identification (BaseCausalTest (treatment_variable , outcome_variable ))
127
+ minimal_adjustment_set = minimal_adjustment_set - {
120
128
v .name for v in causal_test_case .control_input_configuration
121
129
}
122
- minimal_adjustment_set = minimal_adjustment_set - { v .name for v in causal_test_case . outcome_variables }
130
+ minimal_adjustment_set = minimal_adjustment_set - set ( outcome_variable .name )
123
131
assert all (
124
132
(v .name not in minimal_adjustment_set for v in causal_test_case .control_input_configuration )
125
133
), "Treatment vars in adjustment set"
126
- assert all (
127
- ( v . name not in minimal_adjustment_set for v in causal_test_case . outcome_variables )
134
+ assert (
135
+ outcome_variable not in minimal_adjustment_set
128
136
), "Outcome vars in adjustment set"
137
+ variables_for_positivity = list (minimal_adjustment_set ) + [treatment_variable .name ] + [
138
+ outcome_variable .name ]
129
139
130
- variables_for_positivity = list (minimal_adjustment_set ) + treatments + outcomes
131
140
if self ._check_positivity_violation (variables_for_positivity ):
132
141
# TODO: We should allow users to continue because positivity can be overcome with parametric models
133
142
# TODO: When we implement causal contracts, we should also note the positivity violation there
0 commit comments