5
5
import json
6
6
import logging
7
7
8
- from collections .abc import Iterable , Mapping
8
+ from collections .abc import Mapping
9
9
from dataclasses import dataclass
10
10
from pathlib import Path
11
11
from statistics import StatisticsError
22
22
from causal_testing .specification .variable import Input , Meta , Output
23
23
from causal_testing .testing .causal_test_case import CausalTestCase
24
24
from causal_testing .testing .causal_test_result import CausalTestResult
25
- from causal_testing .testing .causal_test_engine import CausalTestEngine
26
25
from causal_testing .testing .estimators import Estimator
27
26
from causal_testing .testing .base_test_case import BaseTestCase
28
27
@@ -49,12 +48,12 @@ class JsonUtility:
49
48
def __init__(self, output_path: str, output_overwrite: bool = False):
    """Initialise an empty utility bound to a single output file.

    :param output_path: Path of the output file; kept on the instance as a ``pathlib.Path``.
    :param output_overwrite: Passed, together with the output path, to ``check_file_exists``.
    """
    # Placeholders only: setup() later assigns scenario, test_plan and data_collector.
    self.input_paths = None
    self.test_plan = None
    self.scenario = None
    self.causal_specification = None

    # One independent, initially empty registry per variable category.
    self.variables = {kind: {} for kind in ("inputs", "outputs", "metas")}

    target = Path(output_path)
    self.output_path = target
    self.check_file_exists(target, output_overwrite)

    self.data_collector = None
58
57
59
58
def set_paths (self , json_path : str , dag_path : str , data_paths : list [str ] = None ):
60
59
"""
@@ -69,6 +68,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
69
68
70
69
def setup (self , scenario : Scenario ):
71
70
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
71
+ data = []
72
72
self .scenario = scenario
73
73
self ._get_scenario_variables ()
74
74
self .scenario .setup_treatment_variables ()
@@ -80,20 +80,21 @@ def setup(self, scenario: Scenario):
80
80
self .test_plan = json .load (f )
81
81
# Populate the data
82
82
if self .input_paths .data_paths :
83
- self . data = pd .concat ([pd .read_csv (data_file , header = 0 ) for data_file in self .input_paths .data_paths ])
84
- if len (self . data ) == 0 :
83
+ data = pd .concat ([pd .read_csv (data_file , header = 0 ) for data_file in self .input_paths .data_paths ])
84
+ if len (data ) == 0 :
85
85
raise ValueError (
86
86
"No data found, either provide a path to a file containing data or manually populate the .data "
87
87
"attribute with a dataframe before calling .setup()"
88
88
)
89
+ self .data_collector = ObservationalDataCollector (self .scenario , data )
89
90
self ._populate_metas ()
90
91
91
92
def _create_abstract_test_case (self , test , mutates , effects ):
92
93
assert len (test ["mutations" ]) == 1
93
94
treatment_var = next (self .scenario .variables [v ] for v in test ["mutations" ])
94
95
95
96
if not treatment_var .distribution :
96
- fitter = Fitter (self .data [treatment_var .name ], distributions = get_common_distributions ())
97
+ fitter = Fitter (self .data_collector . data [treatment_var .name ], distributions = get_common_distributions ())
97
98
fitter .fit ()
98
99
(dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
99
100
treatment_var .distribution = getattr (scipy .stats , dist )(** params )
@@ -149,6 +150,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
149
150
treatment_value = test ["treatment_value" ],
150
151
estimate_type = test ["estimate_type" ],
151
152
)
153
+
152
154
failed , _ = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
153
155
154
156
msg = (
@@ -231,6 +233,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
231
233
details = []
232
234
if "formula" in test :
233
235
self ._append_to_file (f"Estimator formula used for test: { test ['formula' ]} " )
236
+
234
237
for concrete_test in concrete_tests :
235
238
failed , result = self ._execute_test_case (concrete_test , test , f_flag )
236
239
details .append (result )
@@ -243,10 +246,10 @@ def _populate_metas(self):
243
246
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
244
247
"""
245
248
for meta in self .scenario .variables_of_type (Meta ):
246
- meta .populate (self .data )
249
+ meta .populate (self .data_collector . data )
247
250
248
251
def _execute_test_case (
249
- self , causal_test_case : CausalTestCase , test : Iterable [ Mapping ] , f_flag : bool
252
+ self , causal_test_case : CausalTestCase , test : Mapping , f_flag : bool
250
253
) -> (bool , CausalTestResult ):
251
254
"""Executes a singular test case, prints the results and returns the test case result
252
255
:param causal_test_case: The concrete test case to be executed
@@ -258,10 +261,10 @@ def _execute_test_case(
258
261
"""
259
262
failed = False
260
263
261
- causal_test_engine , estimation_model = self ._setup_test (
262
- causal_test_case , test , test ["conditions" ] if "conditions" in test else None
264
+ estimation_model = self ._setup_test (causal_test_case = causal_test_case , test = test )
265
+ causal_test_result = causal_test_case .execute_test (
266
+ estimator = estimation_model , data_collector = self .data_collector
263
267
)
264
- causal_test_result = causal_test_engine .execute_test (estimation_model , causal_test_case )
265
268
266
269
test_passes = causal_test_case .expected_causal_effect .apply (causal_test_result )
267
270
@@ -283,9 +286,7 @@ def _execute_test_case(
283
286
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
284
287
return failed , causal_test_result
285
288
286
- def _setup_test (
287
- self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
288
- ) -> tuple [CausalTestEngine , Estimator ]:
289
+ def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping ) -> Estimator :
289
290
"""Create the necessary inputs for a single test case
290
291
:param causal_test_case: The concrete test case to be executed
291
292
:param test: Single JSON test definition stored in a mapping (dict)
@@ -296,12 +297,6 @@ def _setup_test(
296
297
- causal_test_engine - Test Engine instance for the test being run
297
298
- estimation_model - Estimator instance for the test being run
298
299
"""
299
-
300
- data_collector = ObservationalDataCollector (
301
- self .scenario , self .data .query (" & " .join (conditions )) if conditions else self .data
302
- )
303
- causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
304
-
305
300
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
306
301
treatment_var = causal_test_case .treatment_variable
307
302
minimal_adjustment_set = minimal_adjustment_set - {treatment_var }
@@ -311,14 +306,13 @@ def _setup_test(
311
306
"control_value" : causal_test_case .control_value ,
312
307
"adjustment_set" : minimal_adjustment_set ,
313
308
"outcome" : causal_test_case .outcome_variable .name ,
314
- "df" : causal_test_engine .scenario_execution_data_df ,
315
309
"effect_modifiers" : causal_test_case .effect_modifier_configuration ,
316
310
"alpha" : test ["alpha" ] if "alpha" in test else 0.05 ,
317
311
}
318
312
if "formula" in test :
319
313
estimator_kwargs ["formula" ] = test ["formula" ]
320
314
estimation_model = test ["estimator" ](** estimator_kwargs )
321
- return causal_test_engine , estimation_model
315
+ return estimation_model
322
316
323
317
def _append_to_file (self , line : str , log_level : int = None ):
324
318
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
0 commit comments