Skip to content

Commit 1d1ea3b

Browse files
Update json tests
1 parent 41a8af1 commit 1d1ea3b

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

causal_testing/json_front/json_class.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,9 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
110110
f"abstract_test \n" + \
111111
f"{abstract_test} \n" + \
112112
f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n" + \
113-
f"Number of concrete tests for test case: {str(len(concrete_tests))}"
114-
115-
self.append_to_file(msg, logging.INFO)
116-
113+
f"Number of concrete tests for test case: {str(len(concrete_tests))} \n" + \
114+
f"{failures}/{len(concrete_tests)} failed for {test['name']}"
115+
self._append_to_file(msg, logging.INFO)
117116

118117
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
119118
failures = 0
@@ -123,7 +122,6 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):
123122
failures += 1
124123
return failures
125124

126-
127125
def _json_parse(self):
128126
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
129127
with open(self.input_paths.json_path, encoding="utf-8") as f:
@@ -133,7 +131,6 @@ def _json_parse(self):
133131
self.data.append(df)
134132
self.data = pd.concat(self.data)
135133

136-
137134
def _populate_metas(self):
138135
"""
139136
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
@@ -146,8 +143,7 @@ def _populate_metas(self):
146143
fitter.fit()
147144
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
148145
var.distribution = getattr(scipy.stats, dist)(**params)
149-
self.append_to_file(var.name + f" {dist}({params})", logging.INFO)
150-
146+
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
151147

152148
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
153149
"""Executes a singular test case, prints the results and returns the test case result
@@ -181,11 +177,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
181177
)
182178
if not test_passes:
183179
failed = True
184-
self.append_to_file(f"FAILED- expected {causal_test_case.expected_causal_effect}, got {result_string}",
185-
logging.WARNING)
180+
self._append_to_file(f"FAILED- expected {causal_test_case.expected_causal_effect}, got {result_string}",
181+
logging.WARNING)
186182
return failed
187183

188-
189184
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
190185
"""Create the necessary inputs for a single test case
191186
:param causal_test_case: The concrete test case to be executed
@@ -214,28 +209,24 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
214209

215210
return causal_test_engine, estimation_model
216211

217-
218212
def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: disable=unused-argument
219213
"""Optional abstract method where user functionality can be written to determine what assumptions are required
220214
for specific test cases
221215
:param estimation_model: estimator model instance for the current running test.
222216
"""
223217
return
224218

225-
226-
def append_to_file(self, line: str, log_level: int = None):
219+
def _append_to_file(self, line: str, log_level: int = None):
227220
with open(self.output_path, "a") as f:
228-
f.write(line+"\n")
221+
f.write(line + "\n")
229222
if log_level:
230223
logger.log(level=log_level, msg=line)
231224

232-
233225
@staticmethod
234226
def check_file_exists(output_path: Path, overwrite: bool):
235227
if not overwrite and output_path.is_file():
236228
raise FileExistsError(f"Chosen file output ({output_path}) already exists")
237229

238-
239230
@staticmethod
240231
def get_args(test_args=None) -> argparse.Namespace:
241232
"""Command-line arguments

tests/json_front_tests/test_json_class.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def setUp(self) -> None:
2929
self.json_path = str(test_data_dir_path / json_file_name)
3030
self.dag_path = str(test_data_dir_path / dag_file_name)
3131
self.data_path = [str(test_data_dir_path / data_file_name)]
32-
self.json_class = JsonUtility("logs.log")
32+
self.json_class = JsonUtility("temp_out.txt", True)
3333
self.example_distribution = scipy.stats.uniform(1, 10)
3434
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
3535
self.output_dict_list = [{"name": "test_output", "datatype": float}]
@@ -95,11 +95,12 @@ def test_generate_tests_from_json(self):
9595
}
9696
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
9797

98-
with self.assertLogs() as captured:
99-
self.json_class.generate_tests(effects, mutates, estimators, False)
98+
self.json_class.generate_tests(effects, mutates, estimators, False)
10099

101100
# Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
102-
self.assertIn("failed", captured.records[-1].getMessage())
101+
with open("temp_out.txt", 'r') as reader:
102+
temp_out = reader.readlines()
103+
self.assertIn("failed", temp_out[-1])
103104

104105
def tearDown(self) -> None:
105106
pass

0 commit comments

Comments
 (0)