Skip to content

Commit 7e01033

Browse files
Merge pull request #170 from CITCOM-project/JSON_frontend_issues
remove set_variables from json frontend Closes #115.
2 parents 626e10e + 83584a7 commit 7e01033

File tree

5 files changed

+41
-54
lines changed

5 files changed

+41
-54
lines changed

causal_testing/json_front/json_class.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import json
66
import logging
77

8-
from abc import ABC
98
from dataclasses import dataclass
109
from pathlib import Path
1110

@@ -26,7 +25,7 @@
2625
logger = logging.getLogger(__name__)
2726

2827

29-
class JsonUtility(ABC):
28+
class JsonUtility:
3029
"""
3130
The JsonUtility Class provides the functionality to use structured JSON to setup and run causal tests on the
3231
CausalTestingFramework.
@@ -48,7 +47,7 @@ def __init__(self, log_path):
4847
self.variables = None
4948
self.data = []
5049
self.test_plan = None
51-
self.modelling_scenario = None
50+
self.scenario = None
5251
self.causal_specification = None
5352
self.setup_logger(log_path)
5453

@@ -61,36 +60,27 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: str):
6160
"""
6261
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6362

64-
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
65-
"""Populate the Causal Variables
66-
:param inputs:
67-
:param outputs:
68-
:param metas:
69-
"""
70-
71-
self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)
72-
73-
def setup(self):
63+
def setup(self, scenario: Scenario):
7464
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
75-
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
76-
self.modelling_scenario.setup_treatment_variables()
65+
self.scenario = scenario
66+
self.scenario.setup_treatment_variables()
7767
self.causal_specification = CausalSpecification(
78-
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
68+
scenario=self.scenario, causal_dag=CausalDAG(self.paths.dag_path)
7969
)
8070
self._json_parse()
8171
self._populate_metas()
8272

8373
def _create_abstract_test_case(self, test, mutates, effects):
8474
assert len(test["mutations"]) == 1
8575
abstract_test = AbstractCausalTestCase(
86-
scenario=self.modelling_scenario,
76+
scenario=self.scenario,
8777
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
88-
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
78+
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
8979
expected_causal_effect={
90-
self.modelling_scenario.variables[variable]: effects[effect]
80+
self.scenario.variables[variable]: effects[effect]
9181
for variable, effect in test["expectedEffect"].items()
9282
},
93-
effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
83+
effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
9484
if "effect_modifiers" in test
9585
else {},
9686
estimate_type=test["estimate_type"],
@@ -141,10 +131,9 @@ def _populate_metas(self):
141131
"""
142132
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
143133
"""
144-
for meta in self.variables.metas:
134+
for meta in self.scenario.variables_of_type(Meta):
145135
meta.populate(self.data)
146-
147-
for var in self.variables.metas + self.variables.outputs:
136+
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
148137
if not var.distribution:
149138
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
150139
fitter.fit()
@@ -195,7 +184,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
195184
- estimation_model - Estimator instance for the test being run
196185
"""
197186

198-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data)
187+
data_collector = ObservationalDataCollector(self.scenario, self.data)
199188
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
200189

201190
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
@@ -289,17 +278,17 @@ def __init__(self, json_path: str, dag_path: str, data_paths: str):
289278
self.data_paths = [Path(path) for path in data_paths]
290279

291280

292-
@dataclass()
281+
@dataclass
293282
class CausalVariables:
294283
"""
295-
A dataclass that converts
284+
A dataclass that converts lists of dictionaries into lists of Causal Variables
296285
"""
297286

298-
inputs: list[Input]
299-
outputs: list[Output]
300-
metas: list[Meta]
301-
302287
def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
303288
self.inputs = [Input(**i) for i in inputs]
304289
self.outputs = [Output(**o) for o in outputs]
305290
self.metas = [Meta(**m) for m in metas] if metas else []
291+
292+
def __iter__(self):
293+
for var in self.inputs + self.outputs + self.metas:
294+
yield var

causal_testing/specification/metamorphic_relation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
191191

192192
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
193193
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
194-
195194
# Case 1: U --> ... --> V
196195
if u in nx.ancestors(dag.graph, v):
197196
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])

examples/poisson/example_run_causal_tests.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def populate_num_shapes_unit(data):
135135
modelling_inputs = (
136136
[Input(i["name"], i["datatype"], i["distribution"]) for i in inputs]
137137
+ [Output(i["name"], i["datatype"]) for i in outputs]
138-
+ ([Meta(i["name"], i["datatype"], [i["populate"]]) for i in metas] if metas else list())
138+
+ ([Meta(i["name"], i["datatype"], i["populate"]) for i in metas] if metas else list())
139139
)
140140

141141
# Create modelling scenario to access z3 variable mirrors
@@ -172,8 +172,7 @@ def test_run_causal_tests():
172172
) # Set the path to the data.csv, dag.dot and causal_tests.json file
173173

174174
# Load the Causal Variables into the JsonUtility class ready to be used in the tests
175-
json_utility.set_variables(inputs, outputs, metas)
176-
json_utility.setup() # Sets up all the necessary parts of the json_class needed to execute tests
175+
json_utility.setup(scenario=modelling_scenario) # Sets up all the necessary parts of the json_class needed to execute tests
177176

178177
json_utility.generate_tests(effects, mutates, estimators, False)
179178

@@ -186,7 +185,6 @@ def test_run_causal_tests():
186185
) # Set the path to the data.csv, dag.dot and causal_tests.json file
187186

188187
# Load the Causal Variables into the JsonUtility class ready to be used in the tests
189-
json_utility.set_variables(inputs, outputs, metas)
190-
json_utility.setup() # Sets up all the necessary parts of the json_class needed to execute tests
188+
json_utility.setup(scenario=modelling_scenario) # Sets up all the necessary parts of the json_class needed to execute tests
191189

192190
json_utility.generate_tests(effects, mutates, estimators, args.f)

tests/json_front_tests/test_json_class.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from causal_testing.testing.estimators import LinearRegressionEstimator
88
from causal_testing.testing.causal_test_outcome import NoEffect
99
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
10-
from causal_testing.json_front.json_class import JsonUtility
10+
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1111
from causal_testing.specification.variable import Input, Output, Meta
1212
from causal_testing.specification.scenario import Scenario
1313
from causal_testing.specification.causal_specification import CausalSpecification
@@ -24,7 +24,7 @@ class TestJsonClass(unittest.TestCase):
2424
def setUp(self) -> None:
2525
json_file_name = "tests.json"
2626
dag_file_name = "dag.dot"
27-
data_file_name = "data.csv"
27+
data_file_name = "data_with_meta.csv"
2828
test_data_dir_path = Path("tests/resources/data")
2929
self.json_path = str(test_data_dir_path / json_file_name)
3030
self.dag_path = str(test_data_dir_path / dag_file_name)
@@ -34,8 +34,11 @@ def setUp(self) -> None:
3434
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
3535
self.output_dict_list = [{"name": "test_output", "datatype": float}]
3636
self.meta_dict_list = [{"name": "test_meta", "datatype": float, "populate": populate_example}]
37-
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, None)
37+
variables = CausalVariables(inputs=self.input_dict_list, outputs=self.output_dict_list,
38+
metas=self.meta_dict_list)
39+
self.scenario = Scenario(variables=variables, constraints=None)
3840
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
41+
self.json_class.setup(self.scenario)
3942

4043
def test_setting_paths(self):
4144
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
@@ -44,33 +47,30 @@ def test_setting_paths(self):
4447

4548
def test_set_inputs(self):
4649
ctf_input = [Input("test_input", float, self.example_distribution)]
47-
self.assertEqual(ctf_input[0].name, self.json_class.variables.inputs[0].name)
48-
self.assertEqual(ctf_input[0].datatype, self.json_class.variables.inputs[0].datatype)
49-
self.assertEqual(ctf_input[0].distribution, self.json_class.variables.inputs[0].distribution)
50+
self.assertEqual(ctf_input[0].name, self.json_class.scenario.variables['test_input'].name)
51+
self.assertEqual(ctf_input[0].datatype, self.json_class.scenario.variables['test_input'].datatype)
52+
self.assertEqual(ctf_input[0].distribution, self.json_class.scenario.variables['test_input'].distribution)
5053

5154
def test_set_outputs(self):
5255
ctf_output = [Output("test_output", float)]
53-
self.assertEqual(ctf_output[0].name, self.json_class.variables.outputs[0].name)
54-
self.assertEqual(ctf_output[0].datatype, self.json_class.variables.outputs[0].datatype)
56+
self.assertEqual(ctf_output[0].name, self.json_class.scenario.variables['test_output'].name)
57+
self.assertEqual(ctf_output[0].datatype, self.json_class.scenario.variables['test_output'].datatype)
5558

5659
def test_set_metas(self):
57-
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, self.meta_dict_list)
5860
ctf_meta = [Meta("test_meta", float, populate_example)]
59-
self.assertEqual(ctf_meta[0].name, self.json_class.variables.metas[0].name)
60-
self.assertEqual(ctf_meta[0].datatype, self.json_class.variables.metas[0].datatype)
61+
self.assertEqual(ctf_meta[0].name, self.json_class.scenario.variables['test_meta'].name)
62+
self.assertEqual(ctf_meta[0].datatype, self.json_class.scenario.variables['test_meta'].datatype)
6163

6264
def test_argparse(self):
6365
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
6466
self.assertEqual(args.data_path, ["data.csv"])
6567
self.assertEqual(args.dag_path, "dag.dot")
6668
self.assertEqual(args.json_path, "tests.json")
6769

68-
def test_setup_modelling_scenario(self):
69-
self.json_class.setup()
70-
self.assertIsInstance(self.json_class.modelling_scenario, Scenario)
70+
def test_setup_scenario(self):
71+
self.assertIsInstance(self.json_class.scenario, Scenario)
7172

7273
def test_setup_causal_specification(self):
73-
self.json_class.setup()
7474
self.assertIsInstance(self.json_class.causal_specification, CausalSpecification)
7575

7676
def test_generate_tests_from_json(self):
@@ -87,12 +87,11 @@ def test_generate_tests_from_json(self):
8787
}
8888
]
8989
}
90-
self.json_class.setup()
9190
self.json_class.test_plan = example_test
9291
effects = {"NoEffect": NoEffect()}
9392
mutates = {
94-
"Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3
95-
> self.json_class.modelling_scenario.variables[x].z3
93+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
94+
> self.json_class.scenario.variables[x].z3
9695
}
9796
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
9897

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
index,test_input,test_output,test_meta
2+
0,1,2,3

0 commit comments

Comments
 (0)