Skip to content

Commit 0776b79

Browse files
authored
Merge pull request #226 from CITCOM-project/test-adequacy
Fixed order-dependency bug in independence test DAG adequacy
2 parents bf0bbe3 + 8bb6648 commit 0776b79

File tree

2 files changed

+121
-13
lines changed

2 files changed

+121
-13
lines changed

causal_testing/testing/causal_test_adequacy.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,26 @@ def __init__(
2626
self.test_suite = test_suite
2727
self.tested_pairs = None
2828
self.pairs_to_test = None
29-
self.untested_edges = None
29+
self.untested_pairs = None
3030
self.dag_adequacy = None
3131

3232
def measure_adequacy(self):
    """
    Calculate the adequacy measurement, and populate the `dag_adequacy` field.

    Every unordered pair of DAG nodes is a relationship to test. A pair joined
    by an edge counts as tested only if the test suite contains a test whose
    (treatment, outcome) matches the direction of that edge. A pair with no
    edge (a causal independence) counts as tested if the suite contains a test
    in either direction.

    Populates `pairs_to_test`, `tested_pairs`, `untested_pairs` and
    `dag_adequacy` (the fraction of pairs that are tested). Pairs are always
    recorded in the orientation produced by `combinations` over the graph's
    nodes, so `tested_pairs` is a subset of `pairs_to_test`.
    """
    graph = self.causal_dag.graph
    edges = graph.edges()
    # Collect the (treatment, outcome) pairs once instead of rescanning the
    # whole test suite for every pair of nodes.
    suite_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}

    self.pairs_to_test = set(combinations(graph.nodes(), 2))
    self.tested_pairs = set()

    for n1, n2 in self.pairs_to_test:
        if (n1, n2) in edges:
            acceptable = {(n1, n2)}
        elif (n2, n1) in edges:
            # The edge runs against the node iteration order. It is still a
            # dependence, so only a test in the edge's direction covers it.
            acceptable = {(n2, n1)}
        else:
            # Causal independences are not order dependent
            acceptable = {(n1, n2), (n2, n1)}
        if not suite_pairs.isdisjoint(acceptable):
            self.tested_pairs.add((n1, n2))

    self.untested_pairs = self.pairs_to_test.difference(self.tested_pairs)
    # A DAG with fewer than two nodes has nothing to test; report it as
    # vacuously adequate instead of raising ZeroDivisionError.
    if self.pairs_to_test:
        self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
    else:
        self.dag_adequacy = 1.0
4050

4151
def to_dict(self):
@@ -45,7 +55,7 @@ def to_dict(self):
4555
"test_suite": self.test_suite,
4656
"tested_pairs": self.tested_pairs,
4757
"pairs_to_test": self.pairs_to_test,
48-
"untested_edges": self.untested_edges,
58+
"untested_pairs": self.untested_pairs,
4959
"dag_adequacy": self.dag_adequacy,
5060
}
5161

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def test_data_adequacy_cateogorical(self):
107107
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100},
108108
)
109109

110-
def test_dag_adequacy(self):
110+
def test_dag_adequacy_dependent(self):
111111
base_test_case = BaseTestCase(
112112
treatment_variable="test_input",
113-
outcome_variable="test_output",
113+
outcome_variable="B",
114114
effect=None,
115115
)
116116
causal_test_case = CausalTestCase(
@@ -128,29 +128,127 @@ def test_dag_adequacy(self):
128128
{
129129
"causal_dag": self.json_class.causal_specification.causal_dag,
130130
"test_suite": test_suite,
131-
"tested_pairs": {("test_input", "test_output")},
131+
"tested_pairs": {("test_input", "B")},
132132
"pairs_to_test": {
133+
("B", "C"),
133134
("test_input_no_dist", "test_input"),
135+
("C", "test_output"),
134136
("test_input", "B"),
135-
("test_input_no_dist", "C"),
136137
("test_input_no_dist", "B"),
138+
("test_input", "test_output"),
137139
("test_input", "C"),
138-
("B", "C"),
139140
("test_input_no_dist", "test_output"),
141+
("B", "test_output"),
142+
("test_input_no_dist", "C"),
143+
},
144+
"untested_pairs": {
145+
("B", "C"),
146+
("test_input_no_dist", "test_input"),
147+
("C", "test_output"),
148+
("test_input_no_dist", "B"),
140149
("test_input", "test_output"),
150+
("test_input", "C"),
151+
("test_input_no_dist", "test_output"),
152+
("B", "test_output"),
153+
("test_input_no_dist", "C"),
154+
},
155+
"dag_adequacy": 0.1,
156+
},
157+
)
158+
159+
def test_dag_adequacy_independent(self):
    """
    Check DAG adequacy when the only test targets the pair ("test_input", "C").

    NOTE(review): presumably these two variables are not joined by an edge in
    the fixture DAG (hence "independent") - confirm against the fixture.
    """
    independence_case = BaseTestCase(treatment_variable="test_input", outcome_variable="C", effect=None)
    wrapped_case = CausalTestCase(
        base_test_case=independence_case,
        expected_causal_effect=None,
        estimate_type=None,
    )
    test_suite = CausalTestSuite()
    test_suite.add_test_object(independence_case, wrapped_case, None, None)

    adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
    adequacy.measure_adequacy()
    print(adequacy.to_dict())

    # The single pair exercised by the suite, and every node pair in the DAG.
    covered = {("test_input", "C")}
    every_pair = {
        ("B", "C"),
        ("B", "test_output"),
        ("C", "test_output"),
        ("test_input", "B"),
        ("test_input", "C"),
        ("test_input", "test_output"),
        ("test_input_no_dist", "B"),
        ("test_input_no_dist", "C"),
        ("test_input_no_dist", "test_input"),
        ("test_input_no_dist", "test_output"),
    }

    expected = {
        "causal_dag": self.json_class.causal_specification.causal_dag,
        "test_suite": test_suite,
        "tested_pairs": covered,
        "pairs_to_test": every_pair,
        # One of ten pairs is covered, leaving the other nine untested.
        "untested_pairs": every_pair - covered,
        "dag_adequacy": 0.1,
    }
    self.assertEqual(adequacy.to_dict(), expected)
207+
208+
def test_dag_adequacy_independent_other_way(self):
209+
base_test_case = BaseTestCase(
210+
treatment_variable="C",
211+
outcome_variable="test_input",
212+
effect=None,
213+
)
214+
causal_test_case = CausalTestCase(
215+
base_test_case=base_test_case,
216+
expected_causal_effect=None,
217+
estimate_type=None,
218+
)
219+
test_suite = CausalTestSuite()
220+
test_suite.add_test_object(base_test_case, causal_test_case, None, None)
221+
dag_adequacy = DAGAdequacy(self.json_class.causal_specification.causal_dag, test_suite)
222+
dag_adequacy.measure_adequacy()
223+
print(dag_adequacy.to_dict())
224+
self.assertEqual(
225+
dag_adequacy.to_dict(),
226+
{
227+
"causal_dag": self.json_class.causal_specification.causal_dag,
228+
"test_suite": test_suite,
229+
"tested_pairs": {("test_input", "C")},
230+
"pairs_to_test": {
231+
("B", "C"),
232+
("test_input_no_dist", "test_input"),
233+
("C", "test_output"),
234+
("test_input", "B"),
148235
("test_input_no_dist", "B"),
236+
("test_input", "test_output"),
149237
("test_input", "C"),
150-
("B", "C"),
151238
("test_input_no_dist", "test_output"),
239+
("B", "test_output"),
240+
("test_input_no_dist", "C"),
241+
},
242+
"untested_pairs": {
243+
("B", "C"),
244+
("test_input_no_dist", "test_input"),
152245
("C", "test_output"),
246+
("test_input_no_dist", "B"),
247+
("test_input", "test_output"),
248+
("test_input", "B"),
249+
("test_input_no_dist", "test_output"),
153250
("B", "test_output"),
251+
("test_input_no_dist", "C"),
154252
},
155253
"dag_adequacy": 0.1,
156254
},

0 commit comments

Comments
 (0)