 import json
 import logging
 
+from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
 from statistics import StatisticsError
@@ -118,8 +119,11 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
 
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
+        test["estimator"] = estimators[test["estimator"]]
+        if "formula" in test:
+            self._append_to_file(f"Estimator formula used for test: {test['formula']}")
         for concrete_test in concrete_tests:
-            failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
+            failed = self._execute_test_case(concrete_test, test, f_flag)
             if failed:
                 failures += 1
         return failures
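
Reviewer note: this hunk changes the contract of `test["estimator"]` from a string naming an estimator to the resolved class itself, mutating the test definition in place so that the class and the optional `formula` travel together into `_execute_test_case`. A minimal sketch of the shapes involved; the stand-in estimator class and the example formula are assumptions for illustration, not taken from this commit:

```python
from typing import Any


class FakeEstimator:
    """Hypothetical stand-in for a concrete Estimator subclass."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


# Name -> class lookup table, as supplied via the `estimators` argument.
estimators = {"FakeEstimator": FakeEstimator}

# A single JSON test definition; "estimator" and "formula" are the only
# keys this diff actually reads.
test = {"estimator": "FakeEstimator", "formula": "y ~ x + I(x ** 2)"}

# Mirrors the added lines: resolve the name to a class once, up front.
test["estimator"] = estimators[test["estimator"]]
assert test["estimator"] is FakeEstimator
```

One consequence of the in-place mutation: `_execute_tests` is no longer idempotent for a given `test` dict, since a second call would index `estimators` with a class rather than a string.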
@@ -147,17 +151,18 @@ def _populate_metas(self):
                 var.distribution = getattr(scipy.stats, dist)(**params)
                 self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
 
-    def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
+    def _execute_test_case(self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :param f_flag: Failure flag that, if True, stops the script executing when a test fails.
         :return: A boolean that, if True, indicates the causal test case failed and, if False, indicates that it
         passed.
         :rtype: bool
         """
         failed = False
 
-        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
+        causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
         )
@@ -183,9 +188,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
             logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed
 
-    def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
+    def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :returns:
             - causal_test_engine - Test Engine instance for the test being run
             - estimation_model - Estimator instance for the test being run
@@ -197,27 +203,21 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
-        estimation_model = estimator(
-            treatment=treatment_var.name,
-            treatment_value=causal_test_case.treatment_value,
-            control_value=causal_test_case.control_value,
-            adjustment_set=minimal_adjustment_set,
-            outcome=causal_test_case.outcome_variable.name,
-            df=causal_test_engine.scenario_execution_data_df,
-            effect_modifiers=causal_test_case.effect_modifier_configuration,
-        )
-
-        self.add_modelling_assumptions(estimation_model)
-
+        estimator_kwargs = {
+            "treatment": treatment_var.name,
+            "treatment_value": causal_test_case.treatment_value,
+            "control_value": causal_test_case.control_value,
+            "adjustment_set": minimal_adjustment_set,
+            "outcome": causal_test_case.outcome_variable.name,
+            "df": causal_test_engine.scenario_execution_data_df,
+            "effect_modifiers": causal_test_case.effect_modifier_configuration,
+        }
+        if "formula" in test:
+            estimator_kwargs["formula"] = test["formula"]
+
+        estimation_model = test["estimator"](**estimator_kwargs)
         return causal_test_engine, estimation_model
 
-    def add_modelling_assumptions(self, estimation_model: Estimator):  # pylint: disable=unused-argument
-        """Optional abstract method where user functionality can be written to determine what assumptions are required
-        for specific test cases
-        :param estimation_model: estimator model instance for the current running test.
-        """
-        return
-
     def _append_to_file(self, line: str, log_level: int = None):
         """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
         logging level.
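
Reviewer note: building `estimator_kwargs` as a dict and unpacking it into the constructor lets estimator-specific options such as `formula` be forwarded only when the JSON test supplies them, instead of every `Estimator` subclass having to accept a `formula` parameter. A self-contained sketch of the pattern, using hypothetical estimator classes:

```python
from typing import Optional


class PlainEstimator:
    """Hypothetical estimator with no formula support."""

    def __init__(self, treatment: str, outcome: str) -> None:
        self.treatment, self.outcome = treatment, outcome


class FormulaEstimator(PlainEstimator):
    """Hypothetical estimator that also accepts a formula."""

    def __init__(self, treatment: str, outcome: str, formula: Optional[str] = None) -> None:
        super().__init__(treatment, outcome)
        self.formula = formula


def build(estimator_cls, test: dict):
    kwargs = {"treatment": "x", "outcome": "y"}
    if "formula" in test:  # forward the key only when the test defines it
        kwargs["formula"] = test["formula"]
    return estimator_cls(**kwargs)


build(PlainEstimator, {})                      # fine: no formula forwarded
build(FormulaEstimator, {"formula": "y ~ x"})  # formula reaches the constructor
```

The failure mode is explicit: a test that supplies `"formula"` to an estimator whose constructor does not accept one raises `TypeError` at construction time rather than being silently dropped. Note the hunk also deletes the no-op `add_modelling_assumptions` hook, so subclasses lose that override point.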
@@ -226,9 +226,7 @@ def _append_to_file(self, line: str, log_level: int = None):
         is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
-            f.write(
-                line + "\n",
-            )
+            f.write(line)
         if log_level:
             logger.log(level=log_level, msg=line)
 
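Reviewer note: as reconstructed, `f.write(line)` no longer appends a trailing `"\n"`, so `_append_to_file` writes exactly what it is given and callers become responsible for their own newlines. If a call site still passes a bare string (for example the new formula log line above), successive appends run together, as this minimal check illustrates:

```python
import io

# io.StringIO stands in for the opened output file; hypothetical check only.
buf = io.StringIO()
buf.write("first line")
buf.write("second line")
assert buf.getvalue() == "first linesecond line"  # no separator without "\n"
```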