Skip to content

Commit 8a4d42f

Browse files
committed
Causal DAG adequacy
1 parent 192d6e4 commit 8a4d42f

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

causal_testing/json_front/json_class.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
129129
test["estimator"] = estimators[test["estimator"]]
130130
if "mutations" in test:
131131
if test["estimate_type"] == "coefficient":
132-
msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
132+
failed, msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
133133
else:
134-
msg = self._run_ate_test(test=test, f_flag=f_flag, effects=effects, mutates=mutates)
134+
failed, msg = self._run_ate_test(test=test, f_flag=f_flag, effects=effects, mutates=mutates)
135135
self._append_to_file(msg, logging.INFO)
136136
else:
137137
outcome_variable = next(
@@ -158,8 +158,10 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
158158
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
159159
+ f"Result: {'FAILED' if failed else 'Passed'}"
160160
)
161-
print(msg)
162161
self._append_to_file(msg, logging.INFO)
162+
test["failed"] = failed
163+
print(msg)
164+
return self.test_plan["tests"]
163165

164166
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
165167
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -181,16 +183,16 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
181183
estimate_type="coefficient",
182184
effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
183185
)
184-
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
186+
failed, result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
185187
msg = (
186188
f"Executing test: {test['name']} \n"
187189
+ f" {causal_test_case} \n"
188190
+ " "
189-
+ ("\n ").join(str(result[1]).split("\n"))
191+
+ ("\n ").join(str(result).split("\n"))
190192
+ "==============\n"
191-
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
193+
+ f" Result: {'FAILED' if failed else 'Passed'}"
192194
)
193-
return msg
195+
return failed, msg
194196

195197
def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
196198
"""Builds structures and runs test case for tests with an estimate_type of 'ate'.
@@ -224,7 +226,7 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
224226
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
225227
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
226228
)
227-
return msg
229+
return failures, msg
228230

229231
def _execute_tests(self, concrete_tests, test, f_flag):
230232
failures = 0
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
This module contains code to measure various aspects of causal test adequacy.
3+
"""
4+
from causal_testing.testing.causal_test_suite import CausalTestSuite
5+
from causal_testing.specification.causal_specification import CausalSpecification
6+
from causal_testing.testing.estimators import Estimator
7+
from causal_testing.testing.causal_test_case import CausalTestCase
8+
from itertools import combinations
9+
10+
11+
class CausalTestAdequacy:
12+
def __init__(
13+
self,
14+
causal_specification: CausalSpecification,
15+
test_suite: CausalTestSuite,
16+
):
17+
self.causal_dag, self.scenario = (
18+
causal_specification.causal_dag,
19+
causal_specification.scenario,
20+
)
21+
self.test_suite = test_suite
22+
self.estimator = estimator
23+
self.dag_adequacy = DAGAdequacy(causal_specification, test_suite)
24+
25+
26+
class DAGAdequacy:
27+
def __init__(
28+
self,
29+
causal_specification: CausalSpecification,
30+
test_suite: CausalTestSuite,
31+
):
32+
self.causal_dag = causal_specification.causal_dag
33+
self.test_suite = test_suite
34+
self.tested_pairs = {
35+
(t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.causal_test_suite
36+
}
37+
self.pairs_to_test = set(combinations(dag.graph.nodes, 2))
38+
self.untested_edges = pairs_to_test.difference(tested_pairs)
39+
self.dag_adequacy = len(tested_pairs) / len(pairs_to_test)

0 commit comments

Comments
 (0)