5
5
import json
6
6
import logging
7
7
8
+ from collections .abc import Iterable , Mapping
8
9
from dataclasses import dataclass
9
10
from pathlib import Path
10
11
from statistics import StatisticsError
@@ -154,10 +155,11 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
154
155
155
156
def _execute_tests (self , concrete_tests , estimators , test , f_flag ):
156
157
failures = 0
158
+ test ["estimator" ] = estimators [test ["estimator" ]]
159
+ if "formula" in test :
160
+ self ._append_to_file (f"Estimator formula used for test: { test ['formula' ]} " )
157
161
for concrete_test in concrete_tests :
158
- failed = self ._execute_test_case (
159
- concrete_test , estimators [test ["estimator" ]], f_flag , test .get ("conditions" , [])
160
- )
162
+ failed = self ._execute_test_case (concrete_test , test , f_flag )
161
163
if failed :
162
164
failures += 1
163
165
return failures
@@ -178,19 +180,26 @@ def _populate_metas(self):
178
180
for meta in self .scenario .variables_of_type (Meta ):
179
181
meta .populate (self .data )
180
182
181
- def _execute_test_case (
182
- self , causal_test_case : CausalTestCase , estimator : Estimator , f_flag : bool , conditions : list [str ]
183
- ) -> bool :
183
+ def _execute_test_case (self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool ) -> bool :
184
184
"""Executes a singular test case, prints the results and returns the test case result
185
185
:param causal_test_case: The concrete test case to be executed
186
+ :param test: Single JSON test definition stored in a mapping (dict)
186
187
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
187
188
:return: A boolean that if True indicates the causal test case passed and if false indicates the test case
188
189
failed.
189
190
:rtype: bool
190
191
"""
191
192
failed = False
192
193
193
- causal_test_engine , estimation_model = self ._setup_test (causal_test_case , estimator , conditions )
194
+ for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
195
+ if not var .distribution :
196
+ fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
197
+ fitter .fit ()
198
+ (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
199
+ var .distribution = getattr (scipy .stats , dist )(** params )
200
+ self ._append_to_file (var .name + f" { dist } ({ params } )" , logging .INFO )
201
+
202
+ causal_test_engine , estimation_model = self ._setup_test (causal_test_case , test )
194
203
causal_test_result = causal_test_engine .execute_test (
195
204
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
196
205
)
@@ -216,11 +225,10 @@ def _execute_test_case(
216
225
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
217
226
return failed
218
227
219
- def _setup_test (
220
- self , causal_test_case : CausalTestCase , estimator : Estimator , conditions : list [str ]
221
- ) -> tuple [CausalTestEngine , Estimator ]:
228
+ def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ]= None ) -> tuple [CausalTestEngine , Estimator ]:
222
229
"""Create the necessary inputs for a single test case
223
230
:param causal_test_case: The concrete test case to be executed
231
+ :param test: Single JSON test definition stored in a mapping (dict)
224
232
:returns:
225
233
- causal_test_engine - Test Engine instance for the test being run
226
234
- estimation_model - Estimator instance for the test being run
@@ -234,27 +242,21 @@ def _setup_test(
234
242
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
235
243
treatment_var = causal_test_case .treatment_variable
236
244
minimal_adjustment_set = minimal_adjustment_set - {treatment_var }
237
- estimation_model = estimator (
238
- treatment = treatment_var .name ,
239
- treatment_value = causal_test_case .treatment_value ,
240
- control_value = causal_test_case .control_value ,
241
- adjustment_set = minimal_adjustment_set ,
242
- outcome = causal_test_case .outcome_variable .name ,
243
- df = causal_test_engine .scenario_execution_data_df ,
244
- effect_modifiers = causal_test_case .effect_modifier_configuration ,
245
- )
246
-
247
- self .add_modelling_assumptions (estimation_model )
248
-
245
+ estimator_kwargs = {
246
+ "treatment" : treatment_var .name ,
247
+ "treatment_value" : causal_test_case .treatment_value ,
248
+ "control_value" : causal_test_case .control_value ,
249
+ "adjustment_set" : minimal_adjustment_set ,
250
+ "outcome" : causal_test_case .outcome_variable .name ,
251
+ "df" : causal_test_engine .scenario_execution_data_df ,
252
+ "effect_modifiers" : causal_test_case .effect_modifier_configuration ,
253
+ }
254
+ if "formula" in test :
255
+ estimator_kwargs ["formula" ] = test ["formula" ]
256
+
257
+ estimation_model = test ["estimator" ](** estimator_kwargs )
249
258
return causal_test_engine , estimation_model
250
259
251
- def add_modelling_assumptions (self , estimation_model : Estimator ): # pylint: disable=unused-argument
252
- """Optional abstract method where user functionality can be written to determine what assumptions are required
253
- for specific test cases
254
- :param estimation_model: estimator model instance for the current running test.
255
- """
256
- return
257
-
258
260
def _append_to_file (self , line : str , log_level : int = None ):
259
261
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
260
262
logging level.
@@ -263,9 +265,7 @@ def _append_to_file(self, line: str, log_level: int = None):
263
265
is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
264
266
"""
265
267
with open (self .output_path , "a" , encoding = "utf-8" ) as f :
266
- f .write (
267
- line + "\n " ,
268
- )
268
+ f .write (line )
269
269
if log_level :
270
270
logger .log (level = log_level , msg = line )
271
271
0 commit comments