Skip to content

Commit 7741448

Browse files
Add append_to_file method
1 parent 999352e commit 7741448

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

causal_testing/json_front/json_class.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77

88
from dataclasses import dataclass
9+
from enum import Enum
910
from pathlib import Path
1011

1112
import pandas as pd
@@ -42,14 +43,15 @@ class JsonUtility:
4243
:attr {CausalSpecification} causal_specification:
4344
"""
4445

45-
def __init__(self, output_path):
46-
self.paths = None
46+
def __init__(self, output_path: str, output_overwrite: bool = False):
47+
self.input_paths = None
4748
self.variables = None
4849
self.data = []
4950
self.test_plan = None
5051
self.scenario = None
5152
self.causal_specification = None
52-
self.check_file_exists(Path(output_path))
53+
self.output_path = Path(output_path)
54+
self.check_file_exists(self.output_path, output_overwrite)
5355

5456
def set_paths(self, json_path: str, dag_path: str, data_paths: str):
5557
"""
@@ -58,14 +60,14 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
5860
:param dag_path: string path representation to the .dot file containing the Causal DAG
5961
:param data_paths: string path representation to the data files
6062
"""
61-
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
63+
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6264

6365
def setup(self, scenario: Scenario):
6466
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
6567
self.scenario = scenario
6668
self.scenario.setup_treatment_variables()
6769
self.causal_specification = CausalSpecification(
68-
scenario=self.scenario, causal_dag=CausalDAG(self.paths.dag_path)
70+
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
6971
)
7072
self._json_parse()
7173
self._populate_metas()
@@ -120,9 +122,9 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
120122

121123
def _json_parse(self):
122124
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
123-
with open(self.paths.json_path, encoding="utf-8") as f:
125+
with open(self.input_paths.json_path, encoding="utf-8") as f:
124126
self.test_plan = json.load(f)
125-
for data_file in self.paths.data_paths:
127+
for data_file in self.input_paths.data_paths:
126128
df = pd.read_csv(data_file, header=0)
127129
self.data.append(df)
128130
self.data = pd.concat(self.data)
@@ -173,7 +175,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
173175
)
174176
if not test_passes:
175177
failed = True
176-
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
178+
self.append_to_file(f"FAILED- expected {causal_test_case.expected_causal_effect}, got {result_string}", logging.WARNING)
177179
return failed
178180

179181
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
@@ -210,11 +212,16 @@ def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: dis
210212
:param estimation_model: estimator model instance for the current running test.
211213
"""
212214
return
215+
def append_to_file(self, line: str, log_level: int = None):
216+
with open(self.output_path, "a") as f:
217+
f.write('\n'.join(line))
218+
if log_level:
219+
logger.log(level=log_level, msg=line)
213220

214221
@staticmethod
215-
def check_file_exists(output_path: Path):
216-
if output_path.is_file():
217-
raise FileExistsError("Chosen file output already exists")
222+
def check_file_exists(output_path: Path, overwrite: bool):
223+
if not overwrite and output_path.is_file():
224+
raise FileExistsError(f"Chosen file output ({output_path}) already exists")
218225

219226
@staticmethod
220227
def get_args(test_args=None) -> argparse.Namespace:
@@ -291,4 +298,4 @@ def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
291298

292299
def __iter__(self):
293300
for var in self.inputs + self.outputs + self.metas:
294-
yield var
301+
yield var

0 commit comments

Comments
 (0)