Skip to content

Commit f87238d

Browse files
Update tests
1 parent 90429ce commit f87238d

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

tests/json_front_tests/test_json_class.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from causal_testing.testing.estimators import LinearRegressionEstimator
88
from causal_testing.testing.causal_test_outcome import NoEffect
99
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
10-
from causal_testing.json_front.json_class import JsonUtility
10+
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1111
from causal_testing.specification.variable import Input, Output, Meta
1212
from causal_testing.specification.scenario import Scenario
1313
from causal_testing.specification.causal_specification import CausalSpecification
@@ -24,7 +24,7 @@ class TestJsonClass(unittest.TestCase):
2424
def setUp(self) -> None:
2525
json_file_name = "tests.json"
2626
dag_file_name = "dag.dot"
27-
data_file_name = "data.csv"
27+
data_file_name = "data_with_meta.csv"
2828
test_data_dir_path = Path("tests/resources/data")
2929
self.json_path = str(test_data_dir_path / json_file_name)
3030
self.dag_path = str(test_data_dir_path / dag_file_name)
@@ -34,8 +34,11 @@ def setUp(self) -> None:
3434
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
3535
self.output_dict_list = [{"name": "test_output", "datatype": float}]
3636
self.meta_dict_list = [{"name": "test_meta", "datatype": float, "populate": populate_example}]
37-
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, None)
37+
variables = CausalVariables(inputs=self.input_dict_list, outputs=self.output_dict_list,
38+
metas=self.meta_dict_list)
39+
self.scenario = Scenario(variables=variables, constraints=None)
3840
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
41+
self.json_class.setup(self.scenario)
3942

4043
def test_setting_paths(self):
4144
self.assertEqual(self.json_class.paths.json_path, Path(self.json_path))
@@ -44,33 +47,30 @@ def test_setting_paths(self):
4447

4548
def test_set_inputs(self):
4649
ctf_input = [Input("test_input", float, self.example_distribution)]
47-
self.assertEqual(ctf_input[0].name, self.json_class.variables.inputs[0].name)
48-
self.assertEqual(ctf_input[0].datatype, self.json_class.variables.inputs[0].datatype)
49-
self.assertEqual(ctf_input[0].distribution, self.json_class.variables.inputs[0].distribution)
50+
self.assertEqual(ctf_input[0].name, self.json_class.scenario.variables['test_input'].name)
51+
self.assertEqual(ctf_input[0].datatype, self.json_class.scenario.variables['test_input'].datatype)
52+
self.assertEqual(ctf_input[0].distribution, self.json_class.scenario.variables['test_input'].distribution)
5053

5154
def test_set_outputs(self):
5255
ctf_output = [Output("test_output", float)]
53-
self.assertEqual(ctf_output[0].name, self.json_class.variables.outputs[0].name)
54-
self.assertEqual(ctf_output[0].datatype, self.json_class.variables.outputs[0].datatype)
56+
self.assertEqual(ctf_output[0].name, self.json_class.scenario.variables['test_output'].name)
57+
self.assertEqual(ctf_output[0].datatype, self.json_class.scenario.variables['test_output'].datatype)
5558

5659
def test_set_metas(self):
57-
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, self.meta_dict_list)
5860
ctf_meta = [Meta("test_meta", float, populate_example)]
59-
self.assertEqual(ctf_meta[0].name, self.json_class.variables.metas[0].name)
60-
self.assertEqual(ctf_meta[0].datatype, self.json_class.variables.metas[0].datatype)
61+
self.assertEqual(ctf_meta[0].name, self.json_class.scenario.variables['test_meta'].name)
62+
self.assertEqual(ctf_meta[0].datatype, self.json_class.scenario.variables['test_meta'].datatype)
6163

6264
def test_argparse(self):
6365
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
6466
self.assertEqual(args.data_path, ["data.csv"])
6567
self.assertEqual(args.dag_path, "dag.dot")
6668
self.assertEqual(args.json_path, "tests.json")
6769

68-
def test_setup_modelling_scenario(self):
69-
self.json_class.setup()
70-
self.assertIsInstance(self.json_class.modelling_scenario, Scenario)
70+
def test_setup_scenario(self):
71+
self.assertIsInstance(self.json_class.scenario, Scenario)
7172

7273
def test_setup_causal_specification(self):
73-
self.json_class.setup()
7474
self.assertIsInstance(self.json_class.causal_specification, CausalSpecification)
7575

7676
def test_generate_tests_from_json(self):
@@ -87,12 +87,11 @@ def test_generate_tests_from_json(self):
8787
}
8888
]
8989
}
90-
self.json_class.setup()
9190
self.json_class.test_plan = example_test
9291
effects = {"NoEffect": NoEffect()}
9392
mutates = {
94-
"Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3
95-
> self.json_class.modelling_scenario.variables[x].z3
93+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
94+
> self.json_class.scenario.variables[x].z3
9695
}
9796
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
9897

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
index,test_input,test_output,test_meta
2+
0,1,2,3

0 commit comments

Comments
 (0)