
Commit e98a81f

Merge branch 'main' into json_moddeling_assumption_method
# Conflicts:
#	causal_testing/json_front/json_class.py
2 parents fcd934f + b68aee3 · commit e98a81f

File tree

3 files changed: +90 −33 lines changed


causal_testing/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -12,4 +12,3 @@
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
-logger.addHandler(logging.StreamHandler())

causal_testing/json_front/json_class.py

Lines changed: 57 additions & 25 deletions
@@ -8,6 +8,7 @@
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass
 from pathlib import Path
+from statistics import StatisticsError
 
 import pandas as pd
 import scipy
@@ -43,14 +44,15 @@ class JsonUtility:
     :attr {CausalSpecification} causal_specification:
     """
 
-    def __init__(self, log_path):
-        self.paths = None
+    def __init__(self, output_path: str, output_overwrite: bool = False):
+        self.input_paths = None
         self.variables = None
         self.data = []
         self.test_plan = None
         self.scenario = None
         self.causal_specification = None
-        self.setup_logger(log_path)
+        self.output_path = Path(output_path)
+        self.check_file_exists(self.output_path, output_overwrite)
 
     def set_paths(self, json_path: str, dag_path: str, data_paths: str):
         """
@@ -59,14 +61,14 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
         :param dag_path: string path representation to the .dot file containing the Causal DAG
         :param data_paths: string path representation to the data files
         """
-        self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
+        self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
 
     def setup(self, scenario: Scenario):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
         self.scenario = scenario
         self.scenario.setup_treatment_variables()
         self.causal_specification = CausalSpecification(
-            scenario=self.scenario, causal_dag=CausalDAG(self.paths.dag_path)
+            scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
         )
         self._json_parse()
         self._populate_metas()
@@ -104,12 +106,16 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
             abstract_test = self._create_abstract_test_case(test, mutates, effects)
 
             concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-            logger.info("Executing test: %s", test["name"])
-            logger.info(abstract_test)
-            logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
-            logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
             failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-            logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
+            msg = (
+                f"Executing test: {test['name']} \n"
+                + "abstract_test \n"
+                + f"{abstract_test} \n"
+                + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n"
+                + f"Number of concrete tests for test case: {str(len(concrete_tests))} \n"
+                + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
+            )
+            self._append_to_file(msg, logging.INFO)
 
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
@@ -122,9 +128,9 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
 
     def _json_parse(self):
         """Parse a JSON input file into inputs, outputs, metas and a test plan"""
-        with open(self.paths.json_path, encoding="utf-8") as f:
+        with open(self.input_paths.json_path, encoding="utf-8") as f:
            self.test_plan = json.load(f)
-        for data_file in self.paths.data_paths:
+        for data_file in self.input_paths.data_paths:
            df = pd.read_csv(data_file, header=0)
            self.data.append(df)
        self.data = pd.concat(self.data)
@@ -141,7 +147,7 @@ def _populate_metas(self):
             fitter.fit()
             (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
             var.distribution = getattr(scipy.stats, dist)(**params)
-            logger.info(var.name + f" {dist}({params})")
+            self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
 
     def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
@@ -169,12 +175,13 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
             )
         else:
             result_string = f"{causal_test_result.test_value.value} no confidence intervals"
-        if f_flag:
-            assert test_passes, (
-                f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
-                f"got {result_string}"
-            )
+
         if not test_passes:
+            if f_flag:
+                raise StatisticsError(
+                    f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
+                    f"got {result_string}"
+                )
             failed = True
             logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed
@@ -218,15 +225,34 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
 
         return causal_test_engine, estimation_model
 
+
+
+    def _append_to_file(self, line: str, log_level: int = None):
+        """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
+        logging level.
+        :param line: The line or lines of text to be appended to the file
+        :param log_level: An integer representing the logging level as specified by pythons inbuilt logging module. It
+        is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
+        """
+        with open(self.output_path, "a", encoding="utf-8") as f:
+            f.write(
+                line + "\n",
+            )
+        if log_level:
+            logger.log(level=log_level, msg=line)
+
     @staticmethod
-    def setup_logger(log_path: str):
-        """Setups up logging instance for the module and adds a FileHandler stream so all stdout prints are also
-        sent to the logfile
-        :param log_path: Path specifying location and name of the logging file to be used
+    def check_file_exists(output_path: Path, overwrite: bool):
+        """Method that checks if the given path to an output file already exists. If overwrite is true the check is
+        passed.
+        :param output_path: File path for the output file of the JSON Frontend
+        :param overwrite: bool that if true, the current file can be overwritten
         """
-        setup_log = logging.getLogger(__name__)
-        file_handler = logging.FileHandler(Path(log_path))
-        setup_log.addHandler(file_handler)
+        if output_path.is_file():
+            if overwrite:
+                output_path.unlink()
+            else:
+                raise FileExistsError(f"Chosen file output ({output_path}) already exists")
 
     @staticmethod
     def get_args(test_args=None) -> argparse.Namespace:
@@ -242,6 +268,12 @@ def get_args(test_args=None) -> argparse.Namespace:
             help="if included, the script will stop if a test fails",
             action="store_true",
         )
+        parser.add_argument(
+            "-w",
+            help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
+            "careful",
+            action="store_true",
+        )
         parser.add_argument(
            "--log_path",
            help="Specify a directory to change the location of the log file",

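Taken together, the json_class.py changes replace the old log-file handler with an explicit output file that the frontend appends its results to, guarded by an overwrite check. The following is a minimal sketch of that output-file behaviour, based only on the constructor and check_file_exists shown in this diff; the file name results.txt is a placeholder.

    # Sketch of the overwrite guard introduced in this commit; "results.txt" is illustrative.
    from causal_testing.json_front.json_class import JsonUtility

    JsonUtility("results.txt", output_overwrite=True)  # unlinks any stale results.txt, then proceeds

    open("results.txt", "w").close()  # simulate a leftover output file from a previous run
    try:
        JsonUtility("results.txt")  # default output_overwrite=False
    except FileExistsError as err:
        print(err)  # Chosen file output (results.txt) already exists

This mirrors the new -w command-line flag: without it, an existing output file makes the frontend stop rather than silently overwrite previous results.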
tests/json_front_tests/test_json_class.py

Lines changed: 33 additions & 7 deletions
@@ -1,5 +1,6 @@
 import unittest
 from pathlib import Path
+from statistics import StatisticsError
 import scipy
 import csv
 import json
@@ -29,7 +30,7 @@ def setUp(self) -> None:
         self.json_path = str(test_data_dir_path / json_file_name)
         self.dag_path = str(test_data_dir_path / dag_file_name)
         self.data_path = [str(test_data_dir_path / data_file_name)]
-        self.json_class = JsonUtility("logs.log")
+        self.json_class = JsonUtility("temp_out.txt", True)
         self.example_distribution = scipy.stats.uniform(1, 10)
         self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
         self.output_dict_list = [{"name": "test_output", "datatype": float}]
@@ -41,9 +42,9 @@ def setUp(self) -> None:
         self.json_class.setup(self.scenario)
 
     def test_setting_paths(self):
-        self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
-        self.assertEqual(self.json_class.paths.dag_path, Path(self.dag_path))
-        self.assertEqual(self.json_class.paths.data_paths, [Path(self.data_path[0])])  # Needs to be list of Paths
+        self.assertEqual(self.json_class.input_paths.json_path, Path(self.json_path))
+        self.assertEqual(self.json_class.input_paths.dag_path, Path(self.dag_path))
+        self.assertEqual(self.json_class.input_paths.data_paths, [Path(self.data_path[0])])  # Needs to be list of Paths
 
     def test_set_inputs(self):
         ctf_input = [Input("test_input", float, self.example_distribution)]
@@ -73,6 +74,30 @@ def test_setup_scenario(self):
     def test_setup_causal_specification(self):
         self.assertIsInstance(self.json_class.causal_specification, CausalSpecification)
 
+    def test_f_flag(self):
+        example_test = {
+            "tests": [
+                {
+                    "name": "test1",
+                    "mutations": {"test_input": "Increase"},
+                    "estimator": "LinearRegressionEstimator",
+                    "estimate_type": "ate",
+                    "effect_modifiers": [],
+                    "expectedEffect": {"test_output": "NoEffect"},
+                    "skip": False,
+                }
+            ]
+        }
+        self.json_class.test_plan = example_test
+        effects = {"NoEffect": NoEffect()}
+        mutates = {
+            "Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
+            > self.json_class.scenario.variables[x].z3
+        }
+        estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
+        with self.assertRaises(StatisticsError):
+            self.json_class.generate_tests(effects, mutates, estimators, True)
+
     def test_generate_tests_from_json(self):
         example_test = {
             "tests": [
@@ -95,11 +120,12 @@ def test_generate_tests_from_json(self):
         }
         estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
 
-        with self.assertLogs() as captured:
-            self.json_class.generate_tests(effects, mutates, estimators, False)
+        self.json_class.generate_tests(effects, mutates, estimators, False)
 
         # Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
-        self.assertIn("failed", captured.records[-1].getMessage())
+        with open("temp_out.txt", 'r') as reader:
+            temp_out = reader.readlines()
+        self.assertIn("failed", temp_out[-1])
 
     def tearDown(self) -> None:
         pass

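The updated test reads the appended output back from temp_out.txt instead of capturing log records, which is exactly what the new _append_to_file helper enables: every message is written to the output file and optionally mirrored to the standard logging module. A standalone sketch of that dual-write pattern (the function name, file name, and message below are illustrative, not part of the commit):

    import logging
    from pathlib import Path

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    def append_to_file(output_path: Path, line: str, log_level: int = None):
        # Append the message to the output file, as json_class._append_to_file does.
        with open(output_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")
        # Optionally mirror the same message to the logging module at the given level.
        if log_level:
            logger.log(level=log_level, msg=line)

    append_to_file(Path("temp_out.txt"), "3/5 failed for test1", logging.INFO)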