6
6
import logging
7
7
8
8
from dataclasses import dataclass
9
+ from enum import Enum
9
10
from pathlib import Path
10
11
11
12
import pandas as pd
@@ -42,14 +43,15 @@ class JsonUtility:
42
43
:attr {CausalSpecification} causal_specification:
43
44
"""
44
45
45
- def __init__ (self , output_path ):
46
- self .paths = None
46
+ def __init__ (self , output_path : str , output_overwrite : bool = False ):
47
+ self .input_paths = None
47
48
self .variables = None
48
49
self .data = []
49
50
self .test_plan = None
50
51
self .scenario = None
51
52
self .causal_specification = None
52
- self .check_file_exists (Path (output_path ))
53
+ self .output_path = Path (output_path )
54
+ self .check_file_exists (self .output_path , output_overwrite )
53
55
54
56
def set_paths (self , json_path : str , dag_path : str , data_paths : str ):
55
57
"""
@@ -58,14 +60,14 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
58
60
:param dag_path: string path representation to the .dot file containing the Causal DAG
59
61
:param data_paths: string path representation to the data files
60
62
"""
61
- self .paths = JsonClassPaths (json_path = json_path , dag_path = dag_path , data_paths = data_paths )
63
+ self .input_paths = JsonClassPaths (json_path = json_path , dag_path = dag_path , data_paths = data_paths )
62
64
63
65
def setup (self , scenario : Scenario ):
64
66
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
65
67
self .scenario = scenario
66
68
self .scenario .setup_treatment_variables ()
67
69
self .causal_specification = CausalSpecification (
68
- scenario = self .scenario , causal_dag = CausalDAG (self .paths .dag_path )
70
+ scenario = self .scenario , causal_dag = CausalDAG (self .input_paths .dag_path )
69
71
)
70
72
self ._json_parse ()
71
73
self ._populate_metas ()
@@ -120,9 +122,9 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
120
122
121
123
def _json_parse (self ):
122
124
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
123
- with open (self .paths .json_path , encoding = "utf-8" ) as f :
125
+ with open (self .input_paths .json_path , encoding = "utf-8" ) as f :
124
126
self .test_plan = json .load (f )
125
- for data_file in self .paths .data_paths :
127
+ for data_file in self .input_paths .data_paths :
126
128
df = pd .read_csv (data_file , header = 0 )
127
129
self .data .append (df )
128
130
self .data = pd .concat (self .data )
@@ -173,7 +175,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
173
175
)
174
176
if not test_passes :
175
177
failed = True
176
- logger . warning ( " FAILED- expected %s , got %s " , causal_test_case . expected_causal_effect , result_string )
178
+ self . append_to_file ( f" FAILED- expected { causal_test_case . expected_causal_effect } , got { result_string } " , logging . WARNING )
177
179
return failed
178
180
179
181
def _setup_test (self , causal_test_case : CausalTestCase , estimator : Estimator ) -> tuple [CausalTestEngine , Estimator ]:
@@ -210,11 +212,16 @@ def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: dis
210
212
:param estimation_model: estimator model instance for the current running test.
211
213
"""
212
214
return
215
+ def append_to_file (self , line : str , log_level : int = None ):
216
+ with open (self .output_path , "a" ) as f :
217
+ f .write ('\n ' .join (line ))
218
+ if log_level :
219
+ logger .log (level = log_level , msg = line )
213
220
214
221
@staticmethod
215
- def check_file_exists (output_path : Path ):
216
- if output_path .is_file ():
217
- raise FileExistsError ("Chosen file output already exists" )
222
+ def check_file_exists (output_path : Path , overwrite : bool ):
223
+ if not overwrite and output_path .is_file ():
224
+ raise FileExistsError (f "Chosen file output ( { output_path } ) already exists" )
218
225
219
226
@staticmethod
220
227
def get_args (test_args = None ) -> argparse .Namespace :
@@ -291,4 +298,4 @@ def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
291
298
292
299
def __iter__ (self ):
293
300
for var in self .inputs + self .outputs + self .metas :
294
- yield var
301
+ yield var
0 commit comments