Skip to content

Commit 74fc22d

Browse files
committed
extra tests
1 parent 4a8c7b0 commit 74fc22d

File tree

6 files changed

+55
-31
lines changed

6 files changed

+55
-31
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _generate_concrete_tests(
131131
)
132132

133133
for v in self.scenario.inputs():
134-
if row[v.name] != v.cast(model[v.z3]):
134+
if v.name in row and row[v.name] != v.cast(model[v.z3]):
135135
constraints = "\n ".join([str(c) for c in self.scenario.constraints if v.name in str(c)])
136136
logger.warning(
137137
f"Unable to set variable {v.name} to {row[v.name]} because of constraints\n"

tests/generation_tests/test_abstract_test_case.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44
import numpy as np
55
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
6+
from causal_testing.generation.enum_gen import EnumGen
67
from causal_testing.specification.causal_specification import Scenario
78
from causal_testing.specification.variable import Input, Output
89
from scipy.stats import uniform, rv_discrete
@@ -22,17 +23,6 @@ def __gt__(self, other):
2223
return NotImplemented
2324

2425

25-
class CarGen(rv_discrete):
26-
cars = dict(enumerate(Car, 1))
27-
inverse_cars = {v: k for k, v in cars.items()}
28-
29-
def ppf(self, q, *args, **kwds):
30-
return np.vectorize(self.cars.get)(np.ceil(len(self.cars) * q))
31-
32-
def cdf(self, q, *args, **kwds):
33-
return np.vectorize(self.inverse_cars.get)(q) / len(Car)
34-
35-
3626
class TestAbstractTestCase(unittest.TestCase):
3727
"""
3828
Class to test abstract test cases.
@@ -53,7 +43,7 @@ def setUp(self) -> None:
5343
self.X3 = Input("X3", float, uniform(10, 40))
5444
self.X4 = Input("X4", int, rv_discrete(values=([10], [1])))
5545
self.X5 = Input("X5", bool, rv_discrete(values=(range(2), [0.5, 0.5])))
56-
self.Car = Input("Car", Car, CarGen())
46+
self.Car = Input("Car", Car, EnumGen(Car))
5747
self.Y = Output("Y", int)
5848

5949
def test_generate_concrete_test_cases(self):

tests/json_front_tests/test_json_class.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ def setUp(self) -> None:
3232
self.data_path = [str(test_data_dir_path / data_file_name)]
3333
self.json_class = JsonUtility("temp_out.txt", True)
3434
self.example_distribution = scipy.stats.uniform(1, 10)
35-
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
35+
self.input_dict_list = [
36+
{"name": "test_input", "datatype": float, "distribution": self.example_distribution},
37+
{"name": "test_input_no_dist", "datatype": float},
38+
]
3639
self.output_dict_list = [{"name": "test_output", "datatype": float}]
3740
self.meta_dict_list = [{"name": "test_meta", "datatype": float, "populate": populate_example}]
38-
variables = CausalVariables(inputs=self.input_dict_list, outputs=self.output_dict_list,
39-
metas=self.meta_dict_list)
41+
variables = CausalVariables(
42+
inputs=self.input_dict_list, outputs=self.output_dict_list, metas=self.meta_dict_list
43+
)
4044
self.scenario = Scenario(variables=variables, constraints=None)
4145
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
4246
self.json_class.setup(self.scenario)
@@ -48,19 +52,19 @@ def test_setting_paths(self):
4852

4953
def test_set_inputs(self):
5054
ctf_input = [Input("test_input", float, self.example_distribution)]
51-
self.assertEqual(ctf_input[0].name, self.json_class.scenario.variables['test_input'].name)
52-
self.assertEqual(ctf_input[0].datatype, self.json_class.scenario.variables['test_input'].datatype)
53-
self.assertEqual(ctf_input[0].distribution, self.json_class.scenario.variables['test_input'].distribution)
55+
self.assertEqual(ctf_input[0].name, self.json_class.scenario.variables["test_input"].name)
56+
self.assertEqual(ctf_input[0].datatype, self.json_class.scenario.variables["test_input"].datatype)
57+
self.assertEqual(ctf_input[0].distribution, self.json_class.scenario.variables["test_input"].distribution)
5458

5559
def test_set_outputs(self):
5660
ctf_output = [Output("test_output", float)]
57-
self.assertEqual(ctf_output[0].name, self.json_class.scenario.variables['test_output'].name)
58-
self.assertEqual(ctf_output[0].datatype, self.json_class.scenario.variables['test_output'].datatype)
61+
self.assertEqual(ctf_output[0].name, self.json_class.scenario.variables["test_output"].name)
62+
self.assertEqual(ctf_output[0].datatype, self.json_class.scenario.variables["test_output"].datatype)
5963

6064
def test_set_metas(self):
6165
ctf_meta = [Meta("test_meta", float, populate_example)]
62-
self.assertEqual(ctf_meta[0].name, self.json_class.scenario.variables['test_meta'].name)
63-
self.assertEqual(ctf_meta[0].datatype, self.json_class.scenario.variables['test_meta'].datatype)
66+
self.assertEqual(ctf_meta[0].name, self.json_class.scenario.variables["test_meta"].name)
67+
self.assertEqual(ctf_meta[0].datatype, self.json_class.scenario.variables["test_meta"].datatype)
6468

6569
def test_argparse(self):
6670
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
@@ -92,7 +96,7 @@ def test_f_flag(self):
9296
effects = {"NoEffect": NoEffect()}
9397
mutates = {
9498
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
95-
> self.json_class.scenario.variables[x].z3
99+
> self.json_class.scenario.variables[x].z3
96100
}
97101
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
98102
with self.assertRaises(StatisticsError):
@@ -116,14 +120,44 @@ def test_generate_tests_from_json(self):
116120
effects = {"NoEffect": NoEffect()}
117121
mutates = {
118122
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
119-
> self.json_class.scenario.variables[x].z3
123+
> self.json_class.scenario.variables[x].z3
120124
}
121125
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
122126

123127
self.json_class.generate_tests(effects, mutates, estimators, False)
124128

125129
# Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
126-
with open("temp_out.txt", 'r') as reader:
130+
with open("temp_out.txt", "r") as reader:
131+
temp_out = reader.readlines()
132+
self.assertIn("failed", temp_out[-1])
133+
134+
135+
def test_generate_tests_from_json_no_dist(self):
136+
example_test = {
137+
"tests": [
138+
{
139+
"name": "test1",
140+
"mutations": {"test_input_no_dist": "Increase"},
141+
"estimator": "LinearRegressionEstimator",
142+
"estimate_type": "ate",
143+
"effect_modifiers": [],
144+
"expectedEffect": {"test_output": "NoEffect"},
145+
"skip": False,
146+
}
147+
]
148+
}
149+
self.json_class.test_plan = example_test
150+
effects = {"NoEffect": NoEffect()}
151+
mutates = {
152+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
153+
> self.json_class.scenario.variables[x].z3
154+
}
155+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
156+
157+
self.json_class.generate_tests(effects, mutates, estimators, False)
158+
159+
# Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
160+
with open("temp_out.txt", "r") as reader:
127161
temp_out = reader.readlines()
128162
self.assertIn("failed", temp_out[-1])
129163

tests/resources/data/dag.dot

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
digraph G { test_input -> B; B -> C; test_output -> test_input; test_output -> C}
1+
digraph G { test_input_no_dist; est_input -> B; B -> C; test_output -> test_input; test_output -> C}

tests/resources/data/data.csv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
index,test_input,test_output
2-
0,1,2
1+
index,test_input,test_input_no_dist,test_output
2+
0,1,1,2
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
index,test_input,test_output,test_meta
2-
0,1,2,3
1+
index,test_input,test_input_no_dist,test_output,test_meta
2+
0,1,1,2,3

0 commit comments

Comments
 (0)