 import json
 import logging
 
+from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
 from statistics import StatisticsError
@@ -153,8 +154,11 @@ def _create_abstract_test_case(self, test, mutates, effects):
 
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
+        test["estimator"] = estimators[test["estimator"]]
+        if "formula" in test:
+            self._append_to_file(f"Estimator formula used for test: {test['formula']}")
         for concrete_test in concrete_tests:
-            failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
+            failed = self._execute_test_case(concrete_test, test, f_flag)
             if failed:
                 failures += 1
         return failures
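
For orientation, the hunk above resolves the estimator *name* given in the JSON test definition to the corresponding class once, up front, so the whole `test` mapping can be passed to each concrete test case. A minimal sketch of that resolution step; `StubEstimator` and the field values are hypothetical stand-ins, not taken from this diff:

```python
# Sketch of the new resolution step: the JSON test names its estimator, and
# _execute_tests swaps that name for the class before the concrete tests run.
# StubEstimator and the field values below are hypothetical stand-ins.
class StubEstimator:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

estimators = {"StubEstimator": StubEstimator}
test = {
    "estimator": "StubEstimator",  # estimator referenced by name in the JSON
    "formula": "y ~ x",            # optional; logged when present
}

test["estimator"] = estimators[test["estimator"]]  # the line added above
assert test["estimator"] is StubEstimator
```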
@@ -182,16 +186,17 @@ def _populate_metas(self):
                 var.distribution = getattr(scipy.stats, dist)(**params)
                 self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
 
-    def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
+    def _execute_test_case(self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :param f_flag: Failure flag that, if True, stops the script when a test fails.
         :return: A boolean that, if True, indicates the causal test case failed and, if False, indicates the test
             case passed.
         :rtype: bool
         """
         failed = False
-        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
+        causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
         )
@@ -217,9 +222,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
             logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed
 
-    def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
+    def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :returns:
             - causal_test_engine - Test Engine instance for the test being run
             - estimation_model - Estimator instance for the test being run
@@ -231,27 +237,21 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
-        estimation_model = estimator(
-            treatment=treatment_var.name,
-            treatment_value=causal_test_case.treatment_value,
-            control_value=causal_test_case.control_value,
-            adjustment_set=minimal_adjustment_set,
-            outcome=causal_test_case.outcome_variable.name,
-            df=causal_test_engine.scenario_execution_data_df,
-            effect_modifiers=causal_test_case.effect_modifier_configuration,
-        )
-
-        self.add_modelling_assumptions(estimation_model)
-
+        estimator_kwargs = {
+            "treatment": treatment_var.name,
+            "treatment_value": causal_test_case.treatment_value,
+            "control_value": causal_test_case.control_value,
+            "adjustment_set": minimal_adjustment_set,
+            "outcome": causal_test_case.outcome_variable.name,
+            "df": causal_test_engine.scenario_execution_data_df,
+            "effect_modifiers": causal_test_case.effect_modifier_configuration,
+        }
+        if "formula" in test:
+            estimator_kwargs["formula"] = test["formula"]
+
+        estimation_model = test["estimator"](**estimator_kwargs)
         return causal_test_engine, estimation_model
 
-    def add_modelling_assumptions(self, estimation_model: Estimator):  # pylint: disable=unused-argument
-        """Optional abstract method where user functionality can be written to determine what assumptions are required
-        for specific test cases
-        :param estimation_model: estimator model instance for the current running test.
-        """
-        return
-
     def _append_to_file(self, line: str, log_level: int = None):
         """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
         logging level.
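
The hunk above builds the constructor arguments as a dict so that the optional `formula` key is forwarded only when the JSON test defines one; estimator classes whose `__init__` has no `formula` parameter are therefore unaffected. A minimal sketch of that pattern, assuming a hypothetical `StubEstimator` and illustrative argument values:

```python
# Minimal sketch of the optional-keyword pattern: "formula" is only passed
# through when the JSON test defines it. StubEstimator, build_estimator, and
# the argument values are hypothetical, for illustration only.
class StubEstimator:
    def __init__(self, treatment, outcome, formula=None):
        self.treatment, self.outcome, self.formula = treatment, outcome, formula

def build_estimator(test: dict) -> StubEstimator:
    kwargs = {"treatment": "x", "outcome": "y"}  # illustrative values
    if "formula" in test:
        kwargs["formula"] = test["formula"]
    return test["estimator"](**kwargs)

print(build_estimator({"estimator": StubEstimator, "formula": "y ~ x"}).formula)  # y ~ x
print(build_estimator({"estimator": StubEstimator}).formula)                      # None
```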
@@ -261,7 +261,7 @@ def _append_to_file(self, line: str, log_level: int = None):
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
             f.write(
-                line,
+                line + "\n",
             )
         if log_level:
             logger.log(level=log_level, msg=line)
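
For reference, `write` on a text file object does not append a line ending, so before this fix successive appends ran together on one line. A quick self-contained check of the corrected behaviour; the file name is arbitrary:

```python
# Demonstrates why the "\n" is needed; out.txt is an arbitrary example name.
from pathlib import Path

path = Path("out.txt")
path.write_text("", encoding="utf-8")
for line in ("first", "second"):
    with open(path, "a", encoding="utf-8") as f:
        f.write(line + "\n")  # mirrors the fixed _append_to_file behaviour
print(path.read_text(encoding="utf-8"))  # prints the two lines separately
```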