Commit ea6581b

Add concrete test case method
1 parent dbfe5cd commit ea6581b

3 files changed: 61 additions and 33 deletions

causal_testing/json_front/json_class.py

Lines changed: 56 additions & 28 deletions
@@ -19,6 +19,7 @@
 from causal_testing.specification.causal_specification import CausalSpecification
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Input, Meta, Output
+from causal_testing.testing.base_test_case import BaseTestCase
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_engine import CausalTestEngine
 from causal_testing.testing.estimators import Estimator
@@ -73,6 +74,60 @@ def setup(self, scenario: Scenario):
         self._json_parse()
         self._populate_metas()

+    def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
+        """Runs and evaluates each test case specified in the JSON input
+
+        :param effects: Dictionary mapping effect class instances to string representations.
+        :param mutates: Dictionary mapping mutation functions to string representations.
+        :param estimators: Dictionary mapping estimator classes to string representations.
+        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
+        """
+        failures = 0
+        for test in self.test_plan["tests"]:
+            if "skip" in test and test["skip"]:
+                continue
+            if "mutates" in test:
+                abstract_test = self._create_abstract_test_case(test, mutates, effects)
+
+                concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
+                failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+                msg = (
+                    f"Executing test: {test['name']}\n"
+                    + "abstract_test\n"
+                    + f"{abstract_test}\n"
+                    + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution}\n"
+                    + f"Number of concrete tests for test case: {str(len(concrete_tests))}\n"
+                    + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
+                )
+                self._append_to_file(msg, logging.INFO)
+            else:
+                outcome_variable = next(iter(test['expectedEffect']))  # Take first key from dictionary of expected effect
+                expected_effect = effects[test['expectedEffect'][outcome_variable]]
+                base_test_case = BaseTestCase(treatment_variable=self.variables["inputs"][test["treatment_variable"]],
+                                              outcome_variable=self.variables["outputs"][outcome_variable])
+
+                causal_test_case = CausalTestCase(base_test_case=base_test_case,
+                                                  expected_causal_effect=expected_effect,
+                                                  control_value=test["control_value"],
+                                                  treatment_value=test["treatment_value"],
+                                                  estimate_type=test["estimate_type"])
+
+
+                if self._execute_test_case(causal_test_case=causal_test_case,
+                                           estimator=estimators[test["estimator"]],
+                                           f_flag=f_flag):
+                    result = "failed"
+                else:
+                    result = "passed"
+
+                msg = (
+                    f"Executing test: {test['name']} \n"
+                    + f"treatment variable: {test['treatment_variable']} \n"
+                    + f"outcome_variable = {outcome_variable} \n"
+                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
+                    + f"result - {result} \n"
+                )
+                self._append_to_file(msg, logging.INFO)
     def _create_abstract_test_case(self, test, mutates, effects):
         assert len(test["mutations"]) == 1
         abstract_test = AbstractCausalTestCase(
@@ -91,32 +146,6 @@ def _create_abstract_test_case(self, test, mutates, effects):
         )
         return abstract_test

-    def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
-        """Runs and evaluates each test case specified in the JSON input
-
-        :param effects: Dictionary mapping effect class instances to string representations.
-        :param mutates: Dictionary mapping mutation functions to string representations.
-        :param estimators: Dictionary mapping estimator classes to string representations.
-        :param f_flag: Failure flag that if True the script will stop executing when a test fails.
-        """
-        failures = 0
-        for test in self.test_plan["tests"]:
-            if "skip" in test and test["skip"]:
-                continue
-            abstract_test = self._create_abstract_test_case(test, mutates, effects)
-
-            concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-            failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-            msg = (
-                f"Executing test: {test['name']} \n"
-                + "abstract_test \n"
-                + f"{abstract_test} \n"
-                + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution} \n"
-                + f"Number of concrete tests for test case: {str(len(concrete_tests))} \n"
-                + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
-            )
-            self._append_to_file(msg, logging.INFO)
-
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
         for concrete_test in concrete_tests:
@@ -157,7 +186,6 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
         :rtype: bool
         """
         failed = False
-
         causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
@@ -228,7 +256,7 @@ def _append_to_file(self, line: str, log_level: int = None):
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
             f.write(
-                line + "\n",
+                line,
             )
         if log_level:
             logger.log(level=log_level, msg=line)
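
For orientation, the new else branch consumes a plain (non-mutating) test entry straight from the JSON test plan, reading the keys name, estimator, estimate_type, expectedEffect, treatment_variable, control_value and treatment_value. A minimal sketch of such an entry, written as a Python dict in the style of the unit tests further down; only the keys come from the new code, every name and value is illustrative:

    # Hypothetical test-plan entry handled by the concrete branch of run_json_tests.
    # Only the keys are taken from the new code; all names and values are illustrative.
    example_test = {
        "tests": [
            {
                "name": "example_concrete_test",                   # used in the log message
                "estimator": "LinearRegressionEstimator",          # looked up in the estimators dict
                "estimate_type": "ate",                            # forwarded to CausalTestCase
                "treatment_variable": "example_input",             # looked up in self.variables["inputs"]
                "expectedEffect": {"example_output": "Positive"},  # outcome name -> key in the effects dict
                "control_value": 0,
                "treatment_value": 1,
            }
        ]
    }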

examples/poisson/example_run_causal_tests.py

Lines changed: 2 additions & 2 deletions
@@ -174,7 +174,7 @@ def test_run_causal_tests():
     # Load the Causal Variables into the JsonUtility class ready to be used in the tests
     json_utility.setup(scenario=modelling_scenario)  # Sets up all the necessary parts of the json_class needed to execute tests

-    json_utility.generate_tests(effects, mutates, estimators, False)
+    json_utility.run_json_tests(effects, mutates, estimators, False)


 if __name__ == "__main__":
@@ -187,4 +187,4 @@ def test_run_causal_tests():
     # Load the Causal Variables into the JsonUtility class ready to be used in the tests
     json_utility.setup(scenario=modelling_scenario)  # Sets up all the necessary parts of the json_class needed to execute tests

-    json_utility.generate_tests(effects, mutates, estimators, args.f)
+    json_utility.run_json_tests(effects, mutates, estimators, args.f)
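
Note that the new run_json_tests signature orders its parameters as (effects, estimators, f_flag=False, mutates=None), while the calls above keep the old positional order (effects, mutates, estimators, f_flag). A sketch of an equivalent call using keyword arguments, assuming the same dictionaries are in scope, keeps the mapping onto the new signature explicit:

    # Sketch: keyword arguments map each value onto the new signature unambiguously.
    json_utility.run_json_tests(effects=effects, estimators=estimators, f_flag=args.f, mutates=mutates)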

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,9 @@ def test_f_flag(self):
         }
         estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
         with self.assertRaises(StatisticsError):
-            self.json_class.generate_tests(effects, mutates, estimators, True)
+            self.json_class.run_json_tests(effects, mutates, estimators, True)

-    def test_generate_tests_from_json(self):
+    def test_run_json_tests_from_json(self):
         example_test = {
             "tests": [
                 {
@@ -120,7 +120,7 @@ def test_generate_tests_from_json(self):
         }
         estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}

-        self.json_class.generate_tests(effects, mutates, estimators, False)
+        self.json_class.run_json_tests(effects, mutates, estimators, False)

         # Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
         with open("temp_out.txt", 'r') as reader:
