
Commit d8c21dc

Merge pull request #89 from CITCOM-project/test-coverage-json
Json Frontend Tests
2 parents 030b655 + f2a12fd commit d8c21dc

9 files changed: +203 additions, -63 deletions

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 6 additions & 1 deletion
@@ -144,7 +144,12 @@ def _generate_concrete_tests(
         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
 
     def generate_concrete_tests(
-        self, sample_size: int, target_ks_score: float = None, rct: bool = False, seed: int = 0, hard_max: int = 1000
+        self,
+        sample_size: int,
+        target_ks_score: float = None,
+        rct: bool = False,
+        seed: int = 0,
+        hard_max: int = 1000,
     ) -> tuple[list[CausalTestCase], pd.DataFrame]:
         """Generates a list of `num` concrete test cases.
 
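For reference, a minimal sketch of how the reformatted signature might be called. The `abstract_test` instance is assumed to exist, and the 5 / 0.05 values simply mirror the call made in json_class.py further down this diff.

# Illustrative only: `abstract_test` is an assumed, already-constructed AbstractCausalTestCase.
concrete_tests, runs = abstract_test.generate_concrete_tests(
    sample_size=5,
    target_ks_score=0.05,
    rct=False,
    seed=0,
    hard_max=1000,
)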

causal_testing/json_front/json_class.py

Lines changed: 57 additions & 40 deletions
@@ -67,17 +67,16 @@ def set_path(self, json_path: str, dag_path: str, data_path: str):
         self.dag_path = Path(dag_path)
         self.data_path = Path(data_path)
 
-    def set_variables(self, inputs: dict, outputs: dict, metas: dict, distributions: dict, populates: dict):
+    def set_variables(self, inputs: dict, outputs: dict, metas: dict):
+
         """Populate the Causal Variables
         :param inputs:
         :param outputs:
         :param metas:
-        :param distributions:
-        :param populates:
         """
-        self.inputs = [Input(i["name"], i["type"], distributions[i["distribution"]]) for i in inputs]
+        self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
         self.outputs = [Output(i["name"], i["type"]) for i in outputs]
-        self.metas = [Meta(i["name"], i["type"], populates[i["populate"]]) for i in metas] if metas else []
+        self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
 
     def setup(self):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
@@ -89,54 +88,58 @@ def setup(self):
         self._json_parse()
         self._populate_metas()
 
-    def execute_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
+    def _create_abstract_test_case(self, test, mutates, effects):
+        abstract_test = AbstractCausalTestCase(
+            scenario=self.modelling_scenario,
+            intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
+            treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
+            expected_causal_effect={
+                self.modelling_scenario.variables[variable]: effects[effect]
+                for variable, effect in test["expectedEffect"].items()
+            },
+            effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
+            if "effect_modifiers" in test
+            else {},
+            estimate_type=test["estimate_type"],
+        )
+        return abstract_test
+
+    def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
         """Runs and evaluates each test case specified in the JSON input
 
         :param effects: Dictionary mapping effect class instances to string representations.
         :param mutates: Dictionary mapping mutation functions to string representations.
         :param estimators: Dictionary mapping estimator classes to string representations.
         :param f_flag: Failure flag that if True the script will stop executing when a test fails.
         """
-        executed_tests = 0
         failures = 0
         for test in self.test_plan["tests"]:
             if "skip" in test and test["skip"]:
                 continue
-
-            abstract_test = AbstractCausalTestCase(
-                scenario=self.modelling_scenario,
-                intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
-                treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
-                expected_causal_effect={
-                    self.modelling_scenario.variables[variable]: effects[effect]
-                    for variable, effect in test["expectedEffect"].items()
-                },
-                effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
-                if "effect_modifiers" in test
-                else {},
-                estimate_type=test["estimate_type"],
-            )
+            abstract_test = self._create_abstract_test_case(test, mutates, effects)
 
             concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
             logger.info("Executing test: %s", test["name"])
             logger.info(abstract_test)
             logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
             logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
-            for concrete_test in concrete_tests:
-                executed_tests += 1
-                failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
-                if failed:
-                    failures += 1
+            failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
+
+            logger.info(f"{failures}/{len(concrete_tests)} failed")
 
-        logger.info("{%d}/{%d} failed", failures, executed_tests)
+    def _execute_tests(self, concrete_tests, estimators, test, f_flag):
+        failures = 0
+        for concrete_test in concrete_tests:
+            failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
+            if failed:
+                failures += 1
+        return failures
 
     def _json_parse(self):
-        """Parse a JSON input file into inputs, outputs, metas and a test plan
-        :param distributions: dictionary of user defined scipy distributions
-        :param populates: dictionary of user defined populate functions
-        """
-        with open(self.json_path, encoding="UTF-8") as file:
-            self.test_plan = json.load(file)
+
+        """Parse a JSON input file into inputs, outputs, metas and a test plan"""
+        with open(self.json_path) as f:
+            self.test_plan = json.load(f)
 
         self.data = pd.read_csv(self.data_path)
 
@@ -187,7 +190,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
         if not test_passes:
             failed = True
             logger.warning(
-                " FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, causal_test_result.ate
+                " FAILED- expected %s, got %s",
+                causal_test_case.expected_causal_effect,
+                causal_test_result.ate,
             )
         return failed
 
@@ -235,25 +240,37 @@ def setup_logger(log_path: str):
         setup_log.addHandler(file_handler)
 
     @staticmethod
-    def get_args() -> argparse.Namespace:
+    def get_args(test_args=None) -> argparse.Namespace:
        """Command-line arguments
 
        :return: parsed command line arguments
        """
        parser = argparse.ArgumentParser(
            description="A script for parsing json config files for the Causal Testing Framework"
        )
-        parser.add_argument("-f", help="if included, the script will stop if a test fails", action="store_true")
+        parser.add_argument(
+            "-f",
+            help="if included, the script will stop if a test fails",
+            action="store_true",
+        )
         parser.add_argument(
             "--log_path",
             help="Specify a directory to change the location of the log file",
             default="./json_frontend.log",
         )
-        parser.add_argument("--data_path", help="Specify path to file containing runtime data", required=True)
         parser.add_argument(
-            "--dag_path", help="Specify path to file containing the DAG, normally a .dot file", required=True
+            "--data_path",
+            help="Specify path to file containing runtime data",
+            required=True,
+        )
+        parser.add_argument(
+            "--dag_path",
+            help="Specify path to file containing the DAG, normally a .dot file",
+            required=True,
        )
        parser.add_argument(
-            "--json_path", help="Specify path to file containing JSON tests, normally a .json file", required=True
+            "--json_path",
+            help="Specify path to file containing JSON tests, normally a .json file",
+            required=True,
        )
-        return parser.parse_args()
+        return parser.parse_args(test_args)
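To make the API change above concrete, here is a minimal sketch of the revised JsonUtility workflow. It is condensed from the new unit test added later in this commit; the file names and variable values are illustrative, not canonical.

import scipy.stats

from causal_testing.json_front.json_class import JsonUtility
from causal_testing.testing.causal_test_outcome import NoEffect
from causal_testing.testing.estimators import LinearRegressionEstimator

json_utility = JsonUtility("logs.log")  # constructor argument is the log file path, as in the new tests
json_utility.set_path("tests.json", "dag.dot", "data.csv")  # json_path, dag_path, data_path

# Distributions (and populate callables for metas) are now supplied inline per variable,
# rather than through separate `distributions`/`populates` lookup dictionaries.
inputs = [{"name": "test_input", "type": float, "distribution": scipy.stats.uniform(0, 10)}]
outputs = [{"name": "test_output", "type": float}]
json_utility.set_variables(inputs, outputs, None)

json_utility.setup()  # parses the JSON test plan and runtime data
json_utility.generate_tests(
    effects={"NoEffect": NoEffect()},
    mutates={},  # mapping from mutation name used in the JSON to a constraint-building callable
    estimators={"LinearRegressionEstimator": LinearRegressionEstimator},
    f_flag=False,  # stop on the first failing test when True
)

The old execute_tests entry point is replaced by generate_tests, which delegates per-test execution to the new private _execute_tests helper.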

causal_testing/testing/causal_test_engine.py

Lines changed: 4 additions & 1 deletion
@@ -36,7 +36,10 @@ def __init__(
     ):
         self.causal_test_case = causal_test_case
         self.treatment_variables = list(self.causal_test_case.control_input_configuration)
-        self.casual_dag, self.scenario = causal_specification.causal_dag, causal_specification.scenario
+        self.casual_dag, self.scenario = (
+            causal_specification.causal_dag,
+            causal_specification.scenario,
+        )
         self.data_collector = data_collector
         self.scenario_execution_data_df = pd.DataFrame()

causal_testing/testing/estimators.py

Lines changed: 5 additions & 1 deletion
@@ -306,6 +306,7 @@ def estimate_unit_ate(self) -> float:
         model = self._run_linear_regression()
         unit_effect = model.params[list(self.treatment)].values[0]  # Unit effect is the coefficient of the treatment
         [ci_low, ci_high] = self._get_confidence_intervals(model)
+
         return unit_effect * self.treatment_values - unit_effect * self.control_values, [ci_low, ci_high]
 
     def estimate_ate(self) -> tuple[float, list[float, float], float]:

@@ -437,7 +438,10 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
 
     def _get_confidence_intervals(self, model):
         confidence_intervals = model.conf_int(alpha=0.05, cols=None)
-        ci_low, ci_high = confidence_intervals[0][list(self.treatment)], confidence_intervals[1][list(self.treatment)]
+        ci_low, ci_high = (
+            confidence_intervals[0][list(self.treatment)],
+            confidence_intervals[1][list(self.treatment)],
+        )
         return [ci_low.values[0], ci_high.values[0]]

examples/poisson/run_causal_tests.py

Lines changed: 11 additions & 20 deletions
@@ -92,10 +92,11 @@ def populate_num_shapes_unit(data):
     area = data['width'] * data['height']
     data['num_shapes_unit'] = data['num_shapes_abs'] / area
 
+
 inputs = [
-    {"name": "width", "type": float, "distribution": "uniform"},
-    {"name": "height", "type": float, "distribution": "uniform"},
-    {"name": "intensity", "type": float, "distribution": "uniform"}
+    {"name": "width", "type": float, "distribution": scipy.stats.uniform(0, 10)},
+    {"name": "height", "type": float, "distribution": scipy.stats.uniform(0, 10)},
+    {"name": "intensity", "type": float, "distribution": scipy.stats.uniform(0, 10)}
 ]
 
 outputs = [

@@ -104,23 +105,13 @@ def populate_num_shapes_unit(data):
 ]
 
 metas = [
-    {"name": "num_lines_unit", "type": float, "populate": "populate_num_lines_unit"},
-    {"name": "num_shapes_unit", "type": float, "populate": "populate_num_shapes_unit"},
-    {"name": "width_plus_height", "type": float, "populate": "populate_width_height"}
+    {"name": "num_lines_unit", "type": float, "populate": populate_num_lines_unit},
+    {"name": "num_shapes_unit", "type": float, "populate": populate_num_shapes_unit},
+    {"name": "width_plus_height", "type": float, "populate": populate_width_height}
 ]
 
 constraints = ["width > 0", "height > 0", "intensity > 0"]
 
-populates = {
-    "populate_width_height": populate_width_height,
-    "populate_num_lines_unit": populate_num_lines_unit,
-    "populate_num_shapes_unit": populate_num_shapes_unit
-}
-
-distributions = {
-    "uniform": scipy.stats.uniform(0, 10)
-}
-
 effects = {
     "PoissonWidthHeight": PoissonWidthHeight(),
     "Positive": Positive(),

@@ -136,9 +127,9 @@ def populate_num_shapes_unit(data):
 }
 
 # Create input structure required to create a modelling scenario
-modelling_inputs = [Input(i['name'], i['type'], distributions[i['distribution']]) for i in inputs] + \
+modelling_inputs = [Input(i['name'], i['type'], i['distribution']) for i in inputs] + \
                    [Output(i['name'], i['type']) for i in outputs] + \
-                   [Meta(i['name'], i['type'], populates[i['populate']]) for i in metas] if metas else []
+                   [Meta(i['name'], i['type'], [i['populate']]) for i in metas] if metas else list()
 
 # Create modelling scenario to access z3 variable mirrors
 modelling_scenario = Scenario(modelling_inputs, None)

@@ -173,7 +164,7 @@ def add_modelling_assumptions(self, estimation_model: Estimator):
     args.data_path)  # Set the path to the data.csv, dag.dot and causal_tests.json file
 
 # Load the Causal Variables into the JsonUtility class ready to be used in the tests
-json_utility.set_variables(inputs, outputs, metas, distributions, populates)
+json_utility.set_variables(inputs, outputs, metas)
 json_utility.setup()  # Sets up all the necessary parts of the json_class needed to execute tests
 
-json_utility.execute_tests(effects, mutates, estimators, args.f)
+json_utility.generate_tests(effects, mutates, estimators, args.f)
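With these changes the example is driven entirely through the get_args parser shown earlier, so a run of the script might look like the following. The paths are illustrative, echoing the data.csv / dag.dot / causal_tests.json names mentioned in the comment above; adding -f stops the run at the first failing test.

python run_causal_tests.py --data_path data.csv --dag_path dag.dot --json_path causal_tests.json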
New file: unit tests for the JSON frontend (file path not preserved in this view)

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+import unittest
+from pathlib import Path
+import scipy
+import csv
+import json
+
+from causal_testing.testing.estimators import LinearRegressionEstimator
+from causal_testing.testing.causal_test_outcome import NoEffect
+from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
+from causal_testing.json_front.json_class import JsonUtility
+from causal_testing.specification.variable import Input, Output, Meta
+from causal_testing.specification.scenario import Scenario
+from causal_testing.specification.causal_specification import CausalSpecification
+from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
+
+
+class TestJsonClass(unittest.TestCase):
+    """Test the JSON frontend for the Causal Testing Framework (CTF)
+
+    The JSON frontend is an alternative interface for the CTF where tests are specified in JSON format and ingested
+    with the frontend. Tests involve testing that this correctly interfaces with the framework with some dummy data
+    """
+
+    def setUp(self) -> None:
+        json_file_name = "tests.json"
+        dag_file_name = "dag.dot"
+        data_file_name = "data.csv"
+        test_data_dir_path = Path("tests/resources/data")
+        self.json_path = test_data_dir_path / json_file_name
+        self.dag_path = test_data_dir_path / dag_file_name
+        self.data_path = test_data_dir_path / data_file_name
+        self.json_class = JsonUtility("logs.log")
+        self.example_distribution = scipy.stats.uniform(1, 10)
+        self.input_dict_list = [{"name": "test_input", "type": float, "distribution": self.example_distribution}]
+        self.output_dict_list = [{"name": "test_output", "type": float}]
+        self.meta_dict_list = [{"name": "test_meta", "type": float, "populate": populate_example}]
+        self.json_class.set_variables(self.input_dict_list, self.output_dict_list, None)
+        self.json_class.set_path(self.json_path, self.dag_path, self.data_path)
+
+    def test_setting_paths(self):
+        self.assertEqual(self.json_class.json_path, Path(self.json_path))
+        self.assertEqual(self.json_class.dag_path, Path(self.dag_path))
+        self.assertEqual(self.json_class.data_path, Path(self.data_path))
+
+    def test_set_inputs(self):
+        ctf_input = [Input("test_input", float, self.example_distribution)]
+        self.assertEqual(ctf_input[0].name, self.json_class.inputs[0].name)
+        self.assertEqual(ctf_input[0].datatype, self.json_class.inputs[0].datatype)
+        self.assertEqual(ctf_input[0].distribution, self.json_class.inputs[0].distribution)
+
+    def test_set_outputs(self):
+        ctf_output = [Output("test_output", float)]
+        self.assertEqual(ctf_output[0].name, self.json_class.outputs[0].name)
+        self.assertEqual(ctf_output[0].datatype, self.json_class.outputs[0].datatype)
+
+    def test_set_metas(self):
+        self.json_class.set_variables(self.input_dict_list, self.output_dict_list, self.meta_dict_list)
+        ctf_meta = [Meta("test_meta", float, populate_example)]
+        self.assertEqual(ctf_meta[0].name, self.json_class.metas[0].name)
+        self.assertEqual(ctf_meta[0].datatype, self.json_class.metas[0].datatype)
+
+    def test_argparse(self):
+        args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
+        self.assertEqual(args.data_path, "data.csv")
+        self.assertEqual(args.dag_path, "dag.dot")
+        self.assertEqual(args.json_path, "tests.json")
+
+    def test_setup_modelling_scenario(self):
+        self.json_class.setup()
+        print(type(self.json_class.modelling_scenario))
+        print(self.json_class.modelling_scenario)
+        self.assertIsInstance(self.json_class.modelling_scenario, Scenario)
+
+    def test_setup_causal_specification(self):
+        self.json_class.setup()
+        self.assertIsInstance(self.json_class.causal_specification, CausalSpecification)
+
+    def test_generate_tests_from_json(self):
+        example_test = {
+            "tests": [
+                {
+                    "name": "test1",
+                    "mutations": {"test_input": "Increase"},
+                    "estimator": "LinearRegressionEstimator",
+                    "estimate_type": "ate",
+                    "effect_modifiers": [],
+                    "expectedEffect": {"test_output": "NoEffect"},
+                    "skip": False,
+                }
+            ]
+        }
+        self.json_class.setup()
+        self.json_class.test_plan = example_test
+        effects = {"NoEffect": NoEffect()}
+        mutates = {
+            "Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3 >
+            self.json_class.modelling_scenario.variables[x].z3
+        }
+        estimators = {
+            "LinearRegressionEstimator": LinearRegressionEstimator
+        }
+
+        with self.assertLogs() as captured:
+            self.json_class.generate_tests(effects, mutates, estimators, False)
+
+        # Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
+        self.assertIn("failed", captured.records[-1].getMessage())
+
+    def tearDown(self) -> None:
+        pass
+        #remove_temp_dir_if_existent()
+
+
+def populate_example(*args, **kwargs):
+    pass

tests/resources/data/dag.dot

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+digraph G { test_input -> B; B -> C; test_output -> test_input; test_output -> C}

tests/resources/data/data.csv

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+index,test_input,test_output
+0,1,2

tests/resources/data/tests.json

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"tests": [{"name": "test1", "mutations": {}, "estimator": null, "estimate_type": null, "effect_modifiers": [], "expectedEffect": {}, "skip": false}]}
