Skip to content

Commit aaf49ed

Browse files
committed
coverage
1 parent d8e2a40 commit aaf49ed

File tree

6 files changed

+196
-18
lines changed

6 files changed

+196
-18
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from causal_testing.testing.causal_test_suite import CausalTestSuite
99
from causal_testing.data_collection.data_collector import DataCollector
10-
from causal_testing.specification.causal_specification import CausalSpecification
10+
from causal_testing.specification.causal_dag import CausalDAG
1111
from causal_testing.testing.estimators import Estimator
1212
from causal_testing.testing.causal_test_case import CausalTestCase
1313

@@ -19,10 +19,10 @@ class DAGAdequacy:
1919

2020
def __init__(
2121
self,
22-
causal_specification: CausalSpecification,
22+
causal_dag: CausalDAG,
2323
test_suite: CausalTestSuite,
2424
):
25-
self.causal_dag = causal_specification.causal_dag
25+
self.causal_dag = causal_dag
2626
self.test_suite = test_suite
2727
self.tested_pairs = None
2828
self.pairs_to_test = None
@@ -33,9 +33,7 @@ def measure_adequacy(self):
3333
"""
3434
Calculate the adequacy measurement, and populate the `dat_adequacy` field.
3535
"""
36-
self.tested_pairs = {
37-
(t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.test_suite
38-
}
36+
self.tested_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}
3937
self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
4038
self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
4139
self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)

causal_testing/testing/causal_test_case.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ def execute_test(self, estimator: type(Estimator), data_collector: DataCollector
8383
if estimator.df is None:
8484
estimator.df = data_collector.collect_data()
8585

86-
logger.info("treatments: %s", self.treatment_variable.name)
87-
logger.info("outcomes: %s", self.outcome_variable)
88-
8986
causal_test_result = self._return_causal_test_results(estimator)
9087
return causal_test_result
9188

tests/json_front_tests/test_json_class.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from statistics import StatisticsError
44
import scipy
5+
import os
56

67
from causal_testing.testing.estimators import LinearRegressionEstimator
78
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
@@ -136,9 +137,10 @@ def test_run_json_tests_from_json(self):
136137
"name": "test1",
137138
"mutations": {"test_input": "Increase"},
138139
"estimator": "LinearRegressionEstimator",
139-
"estimate_type": "ate",
140+
"estimate_type": "coefficient",
140141
"effect_modifiers": [],
141142
"expected_effect": {"test_output": "NoEffect"},
143+
"coverage": True,
142144
"skip": False,
143145
}
144146
]
@@ -151,13 +153,10 @@ def test_run_json_tests_from_json(self):
151153
}
152154
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
153155

154-
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False, mutates=mutates)
155-
156-
# Test that the final log message prints that failed tests are printed, which is expected behaviour for this
157-
# scenario
158-
with open("temp_out.txt", "r") as reader:
159-
temp_out = reader.readlines()
160-
self.assertIn("failed", temp_out[-1])
156+
test_results = self.json_class.run_json_tests(
157+
effects=effects, estimators=estimators, f_flag=False, mutates=mutates
158+
)
159+
self.assertTrue(test_results[0]["failed"])
161160

162161
def test_generate_tests_from_json_no_dist(self):
163162
example_test = {
@@ -294,6 +293,8 @@ def test_no_data_provided(self):
294293

295294
def tearDown(self) -> None:
296295
remove_temp_dir_if_existent()
296+
if os.path.exists("temp_out.txt"):
297+
os.remove("temp_out.txt")
297298

298299

299300
def populate_example(*args, **kwargs):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
index,test_input,test_input_no_dist,test_output
2+
0,1,a,2
3+
1,2,b,2
4+
2,3,a,2
5+
3,4,b,2
6+
4,5,a,2
7+
5,6,b,2
8+
6,7,a,2
9+
7,8,b,2

tests/resources/data/tests.json

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1-
{"tests": [{"name": "test1", "mutations": {}, "estimator": null, "estimate_type": null, "effect_modifiers": [], "expected_effect": {}, "skip": false}]}
1+
{
2+
"tests": [{
3+
"name": "test1",
4+
"mutations": {},
5+
"estimator": null,
6+
"estimate_type": null,
7+
"effect_modifiers": [],
8+
"expected_effect": {},
9+
"skip": false
10+
}]
11+
}
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import unittest
2+
from pathlib import Path
3+
from statistics import StatisticsError
4+
import scipy
5+
import os
6+
7+
from causal_testing.testing.estimators import LinearRegressionEstimator
8+
from causal_testing.testing.base_test_case import BaseTestCase
9+
from causal_testing.testing.causal_test_case import CausalTestCase
10+
from causal_testing.testing.causal_test_suite import CausalTestSuite
11+
from causal_testing.testing.causal_test_adequacy import DAGAdequacy
12+
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
13+
from tests.test_helpers import remove_temp_dir_if_existent
14+
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
15+
from causal_testing.specification.variable import Input, Output, Meta
16+
from causal_testing.specification.scenario import Scenario
17+
from causal_testing.specification.causal_specification import CausalSpecification
18+
19+
20+
class TestJsonClass(unittest.TestCase):
21+
"""Test the JSON frontend for the Causal Testing Framework (CTF)
22+
23+
The JSON frontend is an alternative interface for the CTF where tests are specified in JSON format and ingested
24+
with the frontend. Tests involve testing that this correctly interfaces with the framework with some dummy data
25+
"""
26+
27+
def setUp(self) -> None:
28+
json_file_name = "tests.json"
29+
dag_file_name = "dag.dot"
30+
data_file_name = "data_with_categorical.csv"
31+
test_data_dir_path = Path("tests/resources/data")
32+
self.json_path = str(test_data_dir_path / json_file_name)
33+
self.dag_path = str(test_data_dir_path / dag_file_name)
34+
self.data_path = [str(test_data_dir_path / data_file_name)]
35+
self.json_class = JsonUtility("temp_out.txt", True)
36+
self.example_distribution = scipy.stats.uniform(1, 10)
37+
self.input_dict_list = [
38+
{"name": "test_input", "datatype": float, "distribution": self.example_distribution},
39+
{"name": "test_input_no_dist", "datatype": float},
40+
]
41+
self.output_dict_list = [{"name": "test_output", "datatype": float}]
42+
variables = CausalVariables(inputs=self.input_dict_list, outputs=self.output_dict_list, metas=[])
43+
self.scenario = Scenario(variables=variables, constraints=None)
44+
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
45+
self.json_class.setup(self.scenario)
46+
47+
def test_data_adequacy_numeric(self):
48+
example_test = {
49+
"tests": [
50+
{
51+
"name": "test1",
52+
"mutations": {"test_input": "Increase"},
53+
"estimator": "LinearRegressionEstimator",
54+
"estimate_type": "coefficient",
55+
"effect_modifiers": [],
56+
"expected_effect": {"test_output": "NoEffect"},
57+
"coverage": True,
58+
"skip": False,
59+
}
60+
]
61+
}
62+
self.json_class.test_plan = example_test
63+
effects = {"NoEffect": NoEffect()}
64+
mutates = {
65+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
66+
> self.json_class.scenario.variables[x].z3
67+
}
68+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
69+
70+
test_results = self.json_class.run_json_tests(
71+
effects=effects, estimators=estimators, f_flag=False, mutates=mutates
72+
)
73+
self.assertEqual(
74+
test_results[0]["result"].adequacy.to_dict(),
75+
{"kurtosis": {"test_input": 0.0}, "bootstrap_size": 100, "passing": 100},
76+
)
77+
78+
def test_data_adequacy_cateogorical(self):
79+
example_test = {
80+
"tests": [
81+
{
82+
"name": "test1",
83+
"mutations": ["test_input_no_dist"],
84+
"estimator": "LinearRegressionEstimator",
85+
"estimate_type": "coefficient",
86+
"effect_modifiers": [],
87+
"expected_effect": {"test_output": "NoEffect"},
88+
"coverage": True,
89+
"skip": False,
90+
}
91+
]
92+
}
93+
self.json_class.test_plan = example_test
94+
effects = {"NoEffect": NoEffect()}
95+
mutates = {
96+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
97+
> self.json_class.scenario.variables[x].z3
98+
}
99+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
100+
101+
test_results = self.json_class.run_json_tests(
102+
effects=effects, estimators=estimators, f_flag=False, mutates=mutates
103+
)
104+
print("RESULT")
105+
print(test_results[0]["result"])
106+
self.assertEqual(
107+
test_results[0]["result"].adequacy.to_dict(),
108+
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100},
109+
)
110+
111+
def test_dag_adequacy(self):
112+
base_test_case = BaseTestCase(
113+
treatment_variable="test_input",
114+
outcome_variable="test_output",
115+
effect=None,
116+
)
117+
causal_test_case = CausalTestCase(
118+
base_test_case=base_test_case,
119+
expected_causal_effect=None,
120+
estimate_type=None,
121+
)
122+
test_suite = CausalTestSuite()
123+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
124+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
125+
dag_adequacy.measure_adequacy()
126+
print(dag_adequacy.to_dict())
127+
self.assertEqual(
128+
dag_adequacy.to_dict(),
129+
{
130+
"causal_dag": self.json_class.causal_specification.causal_dag,
131+
"test_suite": test_suite,
132+
"tested_pairs": {("test_input", "test_output")},
133+
"pairs_to_test": {
134+
("test_input_no_dist", "test_input"),
135+
("test_input", "B"),
136+
("test_input_no_dist", "C"),
137+
("test_input_no_dist", "B"),
138+
("test_input", "C"),
139+
("B", "C"),
140+
("test_input_no_dist", "test_output"),
141+
("test_input", "test_output"),
142+
("C", "test_output"),
143+
("B", "test_output"),
144+
},
145+
"untested_edges": {
146+
("test_input_no_dist", "test_input"),
147+
("test_input", "B"),
148+
("test_input_no_dist", "C"),
149+
("test_input_no_dist", "B"),
150+
("test_input", "C"),
151+
("B", "C"),
152+
("test_input_no_dist", "test_output"),
153+
("C", "test_output"),
154+
("B", "test_output"),
155+
},
156+
"dag_adequacy": 0.1,
157+
},
158+
)
159+
160+
def tearDown(self) -> None:
161+
remove_temp_dir_if_existent()
162+
if os.path.exists("temp_out.txt"):
163+
os.remove("temp_out.txt")

0 commit comments

Comments
 (0)