5
5
import json
6
6
import logging
7
7
8
- from collections .abc import Iterable , Mapping
8
+ from collections .abc import Mapping
9
9
from dataclasses import dataclass
10
10
from pathlib import Path
11
11
from statistics import StatisticsError
23
23
from causal_testing .specification .variable import Input , Meta , Output
24
24
from causal_testing .testing .causal_test_case import CausalTestCase
25
25
from causal_testing .testing .causal_test_result import CausalTestResult
26
- from causal_testing .testing .causal_test_engine import CausalTestEngine
27
26
from causal_testing .testing .estimators import Estimator
28
27
from causal_testing .testing .base_test_case import BaseTestCase
29
28
from causal_testing .testing .causal_test_adequacy import DataAdequacy
@@ -51,12 +50,12 @@ class JsonUtility:
51
50
def __init__(self, output_path: str, output_overwrite: bool = False) -> None:
    """Initialise empty state and register the output file location.

    :param output_path: Location of the output file; stored as a ``pathlib.Path``.
    :param output_overwrite: Forwarded to ``check_file_exists`` together with the
        output path. NOTE(review): presumably controls whether an existing
        output file may be overwritten -- confirm against ``check_file_exists``.
    """
    self.input_paths = None
    # One name-keyed mapping per variable category.
    self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
    self.test_plan = None  # loaded from the JSON test file in setup()
    self.scenario = None  # assigned in setup()
    self.causal_specification = None
    self.output_path = Path(output_path)
    self.check_file_exists(self.output_path, output_overwrite)
    # Replaces the former ``self.data`` attribute: setup() wraps the loaded
    # dataframe in an ObservationalDataCollector and stores it here.
    self.data_collector = None
60
59
61
60
def set_paths (self , json_path : str , dag_path : str , data_paths : list [str ] = None ):
62
61
"""
@@ -71,6 +70,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
71
70
72
71
def setup (self , scenario : Scenario ):
73
72
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
73
+ data = []
74
74
self .scenario = scenario
75
75
self ._get_scenario_variables ()
76
76
self .scenario .setup_treatment_variables ()
@@ -82,20 +82,21 @@ def setup(self, scenario: Scenario):
82
82
self .test_plan = json .load (f )
83
83
# Populate the data
84
84
if self .input_paths .data_paths :
85
- self . data = pd .concat ([pd .read_csv (data_file , header = 0 ) for data_file in self .input_paths .data_paths ])
86
- if len (self . data ) == 0 :
85
+ data = pd .concat ([pd .read_csv (data_file , header = 0 ) for data_file in self .input_paths .data_paths ])
86
+ if len (data ) == 0 :
87
87
raise ValueError (
88
88
"No data found, either provide a path to a file containing data or manually populate the .data "
89
89
"attribute with a dataframe before calling .setup()"
90
90
)
91
+ self .data_collector = ObservationalDataCollector (self .scenario , data )
91
92
self ._populate_metas ()
92
93
93
94
def _create_abstract_test_case (self , test , mutates , effects ):
94
95
assert len (test ["mutations" ]) == 1
95
96
treatment_var = next (self .scenario .variables [v ] for v in test ["mutations" ])
96
97
97
98
if not treatment_var .distribution :
98
- fitter = Fitter (self .data [treatment_var .name ], distributions = get_common_distributions ())
99
+ fitter = Fitter (self .data_collector . data [treatment_var .name ], distributions = get_common_distributions ())
99
100
fitter .fit ()
100
101
(dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
101
102
treatment_var .distribution = getattr (scipy .stats , dist )(** params )
@@ -134,15 +135,36 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
134
135
failed , msg = self ._run_concrete_metamorphic_test ()
135
136
# If we have a variable to mutate
136
137
else :
137
- if test ["estimate_type" ] == "coefficient" :
138
- failed , msg = self ._run_coefficient_test (test = test , f_flag = f_flag , effects = effects )
139
- else :
140
- failed , msg = self ._run_metamorphic_tests (
141
- test = test , f_flag = f_flag , effects = effects , mutates = mutates
142
- )
138
+ outcome_variable = next (
139
+ iter (test ["expected_effect" ])
140
+ ) # Take first key from dictionary of expected effect
141
+ base_test_case = BaseTestCase (
142
+ treatment_variable = self .variables ["inputs" ][test ["treatment_variable" ]],
143
+ outcome_variable = self .variables ["outputs" ][outcome_variable ],
144
+ )
145
+
146
+ causal_test_case = CausalTestCase (
147
+ base_test_case = base_test_case ,
148
+ expected_causal_effect = effects [test ["expected_effect" ][outcome_variable ]],
149
+ control_value = test ["control_value" ],
150
+ treatment_value = test ["treatment_value" ],
151
+ estimate_type = test ["estimate_type" ],
152
+ )
153
+
154
+ failed , msg = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
155
+
156
+ # msg = (
157
+ # f"Executing concrete test: {test['name']} \n"
158
+ # + f"treatment variable: {test['treatment_variable']} \n"
159
+ # + f"outcome_variable = {outcome_variable} \n"
160
+ # + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
161
+ # + f"Result: {'FAILED' if failed else 'Passed'}"
162
+ # )
163
+ # print(msg)
164
+ self ._append_to_file (msg , logging .INFO )
143
165
test ["failed" ] = failed
144
166
test ["result" ] = msg
145
- return self .test_plan ["tests" ]
167
+ return self .test_plan ["tests" ]
146
168
147
169
def _run_coefficient_test (self , test : dict , f_flag : bool , effects : dict ):
148
170
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -242,6 +264,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
242
264
details = []
243
265
if "formula" in test :
244
266
self ._append_to_file (f"Estimator formula used for test: { test ['formula' ]} " )
267
+
245
268
for concrete_test in concrete_tests :
246
269
failed , result = self ._execute_test_case (concrete_test , test , f_flag )
247
270
details .append (result )
@@ -254,10 +277,10 @@ def _populate_metas(self):
254
277
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
255
278
"""
256
279
for meta in self .scenario .variables_of_type (Meta ):
257
- meta .populate (self .data )
280
+ meta .populate (self .data_collector . data )
258
281
259
282
def _execute_test_case (
260
- self , causal_test_case : CausalTestCase , test : Iterable [ Mapping ] , f_flag : bool
283
+ self , causal_test_case : CausalTestCase , test : Mapping , f_flag : bool
261
284
) -> (bool , CausalTestResult ):
262
285
"""Executes a singular test case, prints the results and returns the test case result
263
286
:param causal_test_case: The concrete test case to be executed
@@ -269,10 +292,10 @@ def _execute_test_case(
269
292
"""
270
293
failed = False
271
294
272
- causal_test_engine , estimation_model = self ._setup_test (
273
- causal_test_case , test , test ["conditions" ] if "conditions" in test else None
295
+ estimation_model = self ._setup_test (causal_test_case = causal_test_case , test = test )
296
+ causal_test_result = causal_test_case .execute_test (
297
+ estimator = estimation_model , data_collector = self .data_collector
274
298
)
275
- causal_test_result = causal_test_engine .execute_test (estimation_model , causal_test_case )
276
299
test_passes = causal_test_case .expected_causal_effect .apply (causal_test_result )
277
300
278
301
if "coverage" in test and test ["coverage" ]:
@@ -300,9 +323,7 @@ def _execute_test_case(
300
323
# logger.warning(" FAILED - expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
301
324
return failed , causal_test_result
302
325
303
- def _setup_test (
304
- self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
305
- ) -> tuple [CausalTestEngine , Estimator ]:
326
+ def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping ) -> Estimator :
306
327
"""Create the necessary inputs for a single test case
307
328
:param causal_test_case: The concrete test case to be executed
308
329
:param test: Single JSON test definition stored in a mapping (dict)
@@ -313,12 +334,6 @@ def _setup_test(
313
334
- causal_test_engine - Test Engine instance for the test being run
314
335
- estimation_model - Estimator instance for the test being run
315
336
"""
316
-
317
- data_collector = ObservationalDataCollector (
318
- self .scenario , self .data .query (" & " .join (conditions )) if conditions else self .data
319
- )
320
- causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
321
-
322
337
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
323
338
treatment_var = causal_test_case .treatment_variable
324
339
minimal_adjustment_set = minimal_adjustment_set - {treatment_var }
@@ -328,14 +343,13 @@ def _setup_test(
328
343
"control_value" : causal_test_case .control_value ,
329
344
"adjustment_set" : minimal_adjustment_set ,
330
345
"outcome" : causal_test_case .outcome_variable .name ,
331
- "df" : causal_test_engine .scenario_execution_data_df ,
332
346
"effect_modifiers" : causal_test_case .effect_modifier_configuration ,
333
347
"alpha" : test ["alpha" ] if "alpha" in test else 0.05 ,
334
348
}
335
349
if "formula" in test :
336
350
estimator_kwargs ["formula" ] = test ["formula" ]
337
351
estimation_model = test ["estimator" ](** estimator_kwargs )
338
- return causal_test_engine , estimation_model
352
+ return estimation_model
339
353
340
354
def _append_to_file (self , line : str , log_level : int = None ):
341
355
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
0 commit comments