import json
import logging

- from abc import ABC
from dataclasses import dataclass
from pathlib import Path
+ from statistics import StatisticsError

import pandas as pd
import scipy

logger = logging.getLogger(__name__)


- class JsonUtility(ABC):
+ class JsonUtility:
    """
    The JsonUtility Class provides the functionality to use structured JSON to setup and run causal tests on the
    CausalTestingFramework.
@@ -40,58 +40,58 @@ class JsonUtility(ABC):
    :attr {Meta} metas: Causal variables representing metavariables.
    :attr {pd.DataFrame}: Pandas DataFrame containing runtime data.
    :attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
-     :attr {Scenario} modelling_scenario:
+     :attr {Scenario} scenario:
    :attr {CausalSpecification} causal_specification:
    """

-     def __init__(self, log_path):
-         self.paths = None
+     def __init__(self, output_path: str, output_overwrite: bool = False):
+         self.input_paths = None
        self.variables = None
        self.data = []
        self.test_plan = None
-         self.modelling_scenario = None
+         self.scenario = None
        self.causal_specification = None
-         self.setup_logger(log_path)
+         self.output_path = Path(output_path)
+         self.check_file_exists(self.output_path, output_overwrite)

    def set_paths(self, json_path: str, dag_path: str, data_paths: str):
        """
        Takes a path of the directory containing all scenario specific files and creates individual paths for each file
        :param json_path: string path representation to .json file containing test specifications
        :param dag_path: string path representation to the .dot file containing the Causal DAG
-         :param data_path: string path representation to the data file
+         :param data_paths: string path representation to the data files
        """
-         self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
+         self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)

-     def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
-         """Populate the Causal Variables
-         :param inputs:
-         :param outputs:
-         :param metas:
-         """
-
-         self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
-
-     def setup(self):
+     def setup(self, scenario: Scenario):
        """Function to populate all the necessary parts of the json_class needed to execute tests"""
-         self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
-         self.modelling_scenario.setup_treatment_variables()
+         self.scenario = scenario
+         self.scenario.setup_treatment_variables()
        self.causal_specification = CausalSpecification(
-             scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
+             scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
        )
        self._json_parse()
        self._populate_metas()

    def _create_abstract_test_case(self, test, mutates, effects):
        assert len(test["mutations"]) == 1
+         treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
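+         # If the treatment variable has no user-specified distribution, fit one from the observed data:
+         # the fitter package tries a set of common distributions and the best fit (by sum-of-squares
+         # error) is attached as a scipy.stats distribution.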
+         if not treatment_var.distribution:
+             fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
+             fitter.fit()
+             (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
+             treatment_var.distribution = getattr(scipy.stats, dist)(**params)
+             self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
+
        abstract_test = AbstractCausalTestCase(
-             scenario=self.modelling_scenario,
+             scenario=self.scenario,
            intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
-             treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
+             treatment_variable=treatment_var,
            expected_causal_effect={
-                 self.modelling_scenario.variables[variable]: effects[effect]
+                 self.scenario.variables[variable]: effects[effect]
                for variable, effect in test["expectedEffect"].items()
            },
-             effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
+             effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
            if "effect_modifiers" in test
            else {},
            estimate_type=test["estimate_type"],
@@ -108,14 +108,15 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
        """
        failures = 0
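+         # Summary of the current test, appended to the output file after the test has run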
+         msg = ""
        for test in self.test_plan["tests"]:
            if "skip" in test and test["skip"]:
                continue

            if test["estimate_type"] == "coefficient":
                base_test_case = BaseTestCase(
-                     treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
-                     outcome_variable=next(self.modelling_scenario.variables[v] for v in test["expectedEffect"]),
+                     treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
+                     outcome_variable=next(self.scenario.variables[v] for v in test["expectedEffect"]),
                    effect=test["effect"],
                )
                assert len(test["expectedEffect"]) == 1, "Can only have one expected effect."
@@ -127,20 +128,29 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
                        ),
                        estimate_type="coefficient",
                        effect_modifier_configuration={
-                             self.modelling_scenario.variables[v] for v in test.get("effect_modifiers", [])
+                             self.scenario.variables[v] for v in test.get("effect_modifiers", [])
                        },
                    )
                ]
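+                 # Run the concrete tests and build a short summary for the output file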
+                 failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+                 msg = (
+                     f"Executing test: {test['name']}\n"
+                     + f" {concrete_tests[0]}\n"
+                     + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                 )
            else:
                abstract_test = self._create_abstract_test_case(test, mutates, effects)
-
                concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-                 logger.info("Executing test: %s", test["name"])
-                 logger.info(abstract_test)
-                 logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
-                 logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
-                 failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-                 logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
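+                 # Run the generated concrete tests and summarise the abstract test for the output file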
+                 failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+                 msg = (
+                     f"Executing test: {test['name']}\n"
+                     + " abstract_test\n"
+                     + f" {abstract_test}\n"
+                     + f" {abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution}\n"
+                     + f" Number of concrete tests for test case: {str(len(concrete_tests))}\n"
+                     + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                 )
+             self._append_to_file(msg, logging.INFO)

    def _execute_tests(self, concrete_tests, estimators, test, f_flag):
        failures = 0
@@ -154,9 +164,9 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):

    def _json_parse(self):
        """Parse a JSON input file into inputs, outputs, metas and a test plan"""
-         with open(self.paths.json_path, encoding="utf-8") as f:
+         with open(self.input_paths.json_path, encoding="utf-8") as f:
            self.test_plan = json.load(f)
-         for data_file in self.paths.data_paths:
+         for data_file in self.input_paths.data_paths:
            df = pd.read_csv(data_file, header=0)
            self.data.append(df)
        self.data = pd.concat(self.data)
@@ -165,20 +175,9 @@ def _populate_metas(self):
        """
        Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
        """
-         for meta in self.variables.metas:
+         for meta in self.scenario.variables_of_type(Meta):
            meta.populate(self.data)

-         for var in self.variables.metas + self.variables.outputs:
-             if not var.distribution:
-                 try:
-                     fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
-                     fitter.fit()
-                     (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
-                     var.distribution = getattr(scipy.stats, dist)(**params)
-                     logger.info(var.name + f" {dist}({params})")
-                 except:
-                     logger.warn(f"Could not fit distriubtion for {var.name}.")
-
    def _execute_test_case(
        self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool, conditions: list[str]
    ) -> bool:
@@ -191,7 +190,6 @@ def _execute_test_case(
        """
        failed = False

-         print(causal_test_case)
        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator, conditions)
        causal_test_result = causal_test_engine.execute_test(
            estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
@@ -207,12 +205,13 @@ def _execute_test_case(
            )
        else:
            result_string = f"{causal_test_result.test_value.value} no confidence intervals"
-         if f_flag:
-             assert test_passes, (
-                 f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
-                 f"got {result_string}"
-             )
+
        if not test_passes:
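+             # With the failure flag set, abort the run as soon as a test fails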
+             if f_flag:
+                 raise StatisticsError(
+                     f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
+                     f"got {result_string}"
+                 )
            failed = True
            logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
        return failed
@@ -228,7 +227,7 @@ def _setup_test(
        """

        data_collector = ObservationalDataCollector(
-             self.modelling_scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
+             self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
        )
        causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)

@@ -256,15 +255,32 @@ def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: dis
        """
        return

+     def _append_to_file(self, line: str, log_level: int = None):
+         """Append the given line(s) to the current output file. If log_level is specified, the message is also
+         logged at that level.
+         :param line: The line or lines of text to be appended to the file
+         :param log_level: An integer logging level as defined by Python's built-in logging module. The built-in
+         constants such as logging.INFO and logging.WARNING can be used.
+         """
+         with open(self.output_path, "a", encoding="utf-8") as f:
+             f.write(
+                 line + "\n",
+             )
+         if log_level:
+             logger.log(level=log_level, msg=line)
+
    @staticmethod
-     def setup_logger(log_path: str):
-         """Setups up logging instance for the module and adds a FileHandler stream so all stdout prints are also
-         sent to the logfile
-         :param log_path: Path specifying location and name of the logging file to be used
+     def check_file_exists(output_path: Path, overwrite: bool):
+         """Check whether the given output file already exists. If overwrite is true, any existing file is
+         removed so that it can be replaced; otherwise a FileExistsError is raised.
+         :param output_path: File path for the output file of the JSON Frontend
+         :param overwrite: bool that, if true, allows an existing file at output_path to be overwritten
        """
-         setup_log = logging.getLogger(__name__)
-         file_handler = logging.FileHandler(Path(log_path))
-         setup_log.addHandler(file_handler)
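+         # Refuse to clobber existing results unless overwriting was explicitly requested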
+         if output_path.is_file():
+             if overwrite:
+                 output_path.unlink()
+             else:
+                 raise FileExistsError(f"Chosen file output ({output_path}) already exists")

    @staticmethod
    def get_args(test_args=None) -> argparse.Namespace:
@@ -280,6 +296,12 @@ def get_args(test_args=None) -> argparse.Namespace:
            help="if included, the script will stop if a test fails",
            action="store_true",
        )
+         parser.add_argument(
+             "-w",
+             help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
+             "careful",
+             action="store_true",
+         )
        parser.add_argument(
            "--log_path",
            help="Specify a directory to change the location of the log file",
@@ -323,17 +345,17 @@ def __init__(self, json_path: str, dag_path: str, data_paths: str):
        self.data_paths = [Path(path) for path in data_paths]


- @dataclass()
+ @dataclass
class CausalVariables:
    """
-     A dataclass that converts
+     A dataclass that converts lists of dictionaries into lists of Causal Variables
    """

-     inputs: list[Input]
-     outputs: list[Output]
-     metas: list[Meta]
-
    def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
        self.inputs = [Input(**i) for i in inputs]
        self.outputs = [Output(**o) for o in outputs]
        self.metas = [Meta(**m) for m in metas] if metas else []
+
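+     # Iterating over a CausalVariables instance yields every input, output and meta variable in turn,
+     # so the full set of variables can be passed around as a single object.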
+     def __iter__(self):
+         for var in self.inputs + self.outputs + self.metas:
+             yield var