5
5
import json
6
6
import logging
7
7
8
+ from collections .abc import Iterable , Mapping
8
9
from dataclasses import dataclass
9
10
from pathlib import Path
10
11
@@ -112,8 +113,9 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
112
113
113
114
def _execute_tests (self , concrete_tests , estimators , test , f_flag ):
114
115
failures = 0
116
+ test ["estimator" ] = estimators [test ["estimator" ]]
115
117
for concrete_test in concrete_tests :
116
- failed = self ._execute_test_case (concrete_test , estimators [ test [ "estimator" ]] , f_flag )
118
+ failed = self ._execute_test_case (concrete_test , test , f_flag )
117
119
if failed :
118
120
failures += 1
119
121
return failures
@@ -141,7 +143,7 @@ def _populate_metas(self):
141
143
var .distribution = getattr (scipy .stats , dist )(** params )
142
144
logger .info (var .name + f" { dist } ({ params } )" )
143
145
144
- def _execute_test_case (self , causal_test_case : CausalTestCase , estimator : Estimator , f_flag : bool ) -> bool :
146
+ def _execute_test_case (self , causal_test_case : CausalTestCase , test : Iterable [ Mapping ] , f_flag : bool ) -> bool :
145
147
"""Executes a singular test case, prints the results and returns the test case result
146
148
:param causal_test_case: The concrete test case to be executed
147
149
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
@@ -151,7 +153,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
151
153
"""
152
154
failed = False
153
155
154
- causal_test_engine , estimation_model = self ._setup_test (causal_test_case , estimator )
156
+ causal_test_engine , estimation_model = self ._setup_test (causal_test_case , test )
155
157
causal_test_result = causal_test_engine .execute_test (
156
158
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
157
159
)
@@ -176,7 +178,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
176
178
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
177
179
return failed
178
180
179
- def _setup_test (self , causal_test_case : CausalTestCase , estimator : Estimator ) -> tuple [CausalTestEngine , Estimator ]:
181
+ def _setup_test (self , causal_test_case : CausalTestCase , test ) -> tuple [CausalTestEngine , Estimator ]:
180
182
"""Create the necessary inputs for a single test case
181
183
:param causal_test_case: The concrete test case to be executed
182
184
:returns:
@@ -190,17 +192,27 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
190
192
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
191
193
treatment_var = causal_test_case .treatment_variable
192
194
minimal_adjustment_set = minimal_adjustment_set - {treatment_var }
193
- estimation_model = estimator (
194
- treatment = treatment_var .name ,
195
- treatment_value = causal_test_case .treatment_value ,
196
- control_value = causal_test_case .control_value ,
197
- adjustment_set = minimal_adjustment_set ,
198
- outcome = causal_test_case .outcome_variable .name ,
199
- df = causal_test_engine .scenario_execution_data_df ,
200
- effect_modifiers = causal_test_case .effect_modifier_configuration ,
201
- )
202
-
203
- self .add_modelling_assumptions (estimation_model )
195
+ if "formula" in test :
196
+ estimation_model = test ["estimator" ](
197
+ treatment = treatment_var .name ,
198
+ treatment_value = causal_test_case .treatment_value ,
199
+ control_value = causal_test_case .control_value ,
200
+ adjustment_set = minimal_adjustment_set ,
201
+ outcome = causal_test_case .outcome_variable .name ,
202
+ df = causal_test_engine .scenario_execution_data_df ,
203
+ effect_modifiers = causal_test_case .effect_modifier_configuration ,
204
+ formula = test ["formula" ]
205
+ )
206
+ else :
207
+ estimation_model = test ["estimator" ](
208
+ treatment = treatment_var .name ,
209
+ treatment_value = causal_test_case .treatment_value ,
210
+ control_value = causal_test_case .control_value ,
211
+ adjustment_set = minimal_adjustment_set ,
212
+ outcome = causal_test_case .outcome_variable .name ,
213
+ df = causal_test_engine .scenario_execution_data_df ,
214
+ effect_modifiers = causal_test_case .effect_modifier_configuration ,
215
+ )
204
216
205
217
return causal_test_engine , estimation_model
206
218
0 commit comments