Skip to content

Commit 73eb5f1

Browse files
committed
Removed unnecessary methods and arguments from metamorphic relation
1 parent b3b5261 commit 73eb5f1

File tree

2 files changed

+53
-86
lines changed

2 files changed

+53
-86
lines changed

causal_testing/testing/metamorphic_relation.py

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"""
55

66
from dataclasses import dataclass
7-
from abc import abstractmethod
87
from typing import Iterable
98
from itertools import combinations
109
import argparse
@@ -26,18 +25,6 @@ class MetamorphicRelation:
2625

2726
base_test_case: BaseTestCase
2827
adjustment_vars: Iterable[Node]
29-
dag: CausalDAG
30-
tests: Iterable = None
31-
32-
@abstractmethod
33-
def to_json_stub(self, skip=True) -> dict:
34-
"""Convert to a JSON frontend stub string for user customisation"""
35-
36-
@abstractmethod
37-
def test_oracle(self, test_results):
38-
"""A test oracle that assert whether the MR holds or not based on ALL test results.
39-
40-
This method must raise an assertion, not return a bool."""
4128

4229
def __eq__(self, other):
4330
same_type = self.__class__ == other.__class__
@@ -51,16 +38,6 @@ def __eq__(self, other):
5138
class ShouldCause(MetamorphicRelation):
5239
"""Class representing a should cause metamorphic relation."""
5340

54-
def assertion(self, source_output, follow_up_output):
55-
"""If there is a causal effect, the outputs should not be the same."""
56-
return source_output != follow_up_output
57-
58-
def test_oracle(self, test_results):
59-
"""A single passing test is sufficient to show presence of a causal effect."""
60-
assert len(test_results["fail"]) < len(
61-
self.tests
62-
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
63-
6441
def to_json_stub(self, skip=True) -> dict:
6542
"""Convert to a JSON frontend stub string for user customisation"""
6643
return {
@@ -84,16 +61,6 @@ def __str__(self):
8461
class ShouldNotCause(MetamorphicRelation):
8562
"""Class representing a should cause metamorphic relation."""
8663

87-
def assertion(self, source_output, follow_up_output):
88-
"""If there is a causal effect, the outputs should not be the same."""
89-
return source_output == follow_up_output
90-
91-
def test_oracle(self, test_results):
92-
"""A single passing test is sufficient to show presence of a causal effect."""
93-
assert (
94-
len(test_results["fail"]) == 0
95-
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
96-
9764
def to_json_stub(self, skip=True) -> dict:
9865
"""Convert to a JSON frontend stub string for user customisation"""
9966
return {
@@ -140,30 +107,30 @@ def generate_metamorphic_relation(
140107
if u in nx.ancestors(dag.graph, v):
141108
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
142109
if adj_sets:
143-
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
110+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))
144111

145112
# Case 2: V --> ... --> U
146113
elif v in nx.ancestors(dag.graph, u):
147114
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
148115
if adj_sets:
149-
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0]), dag))
116+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0])))
150117

151118
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
152119
# Only make one MR since V _||_ U == U _||_ V
153120
else:
154121
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
155122
if adj_sets:
156-
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
123+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0])))
157124

158125
# Create a ShouldCause relation for each edge (u, v) or (v, u)
159126
elif (u, v) in dag.graph.edges:
160127
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
161128
if adj_sets:
162-
metamorphic_relations.append(ShouldCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
129+
metamorphic_relations.append(ShouldCause(BaseTestCase(u, v), list(adj_sets[0])))
163130
else:
164131
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
165132
if adj_sets:
166-
metamorphic_relations.append(ShouldCause(BaseTestCase(v, u), list(adj_sets[0]), dag))
133+
metamorphic_relations.append(ShouldCause(BaseTestCase(v, u), list(adj_sets[0])))
167134
return metamorphic_relations
168135

169136

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_should_not_cause_json_stub(self):
4747
causal_dag = CausalDAG(self.dag_dot_path)
4848
causal_dag.graph.remove_nodes_from(["X2", "X3"])
4949
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
50-
should_not_cause_MR = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set, causal_dag)
50+
should_not_cause_MR = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
5151
self.assertEqual(
5252
should_not_cause_MR.to_json_stub(),
5353
{
@@ -70,7 +70,7 @@ def test_should_cause_json_stub(self):
7070
causal_dag = CausalDAG(self.dag_dot_path)
7171
causal_dag.graph.remove_nodes_from(["X2", "X3"])
7272
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
73-
should_cause_MR = ShouldCause(BaseTestCase("X1", "Z"), adj_set, causal_dag)
73+
should_cause_MR = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
7474
self.assertEqual(
7575
should_cause_MR.to_json_stub(),
7676
{
@@ -94,12 +94,12 @@ def test_all_metamorphic_relations_implied_by_dag(self):
9494

9595
# Check all ShouldCause relations are present and no extra
9696
expected_should_cause_relations = [
97-
ShouldCause(BaseTestCase("X1", "Z"), [], dag),
98-
ShouldCause(BaseTestCase("Z", "M"), [], dag),
99-
ShouldCause(BaseTestCase("M", "Y"), ["Z"], dag),
100-
ShouldCause(BaseTestCase("Z", "Y"), ["M"], dag),
101-
ShouldCause(BaseTestCase("X2", "Z"), [], dag),
102-
ShouldCause(BaseTestCase("X3", "M"), [], dag),
97+
ShouldCause(BaseTestCase("X1", "Z"), []),
98+
ShouldCause(BaseTestCase("Z", "M"), []),
99+
ShouldCause(BaseTestCase("M", "Y"), ["Z"]),
100+
ShouldCause(BaseTestCase("Z", "Y"), ["M"]),
101+
ShouldCause(BaseTestCase("X2", "Z"), []),
102+
ShouldCause(BaseTestCase("X3", "M"), []),
103103
]
104104

105105
extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations]
@@ -110,15 +110,15 @@ def test_all_metamorphic_relations_implied_by_dag(self):
110110

111111
# Check all ShouldNotCause relations are present and no extra
112112
expected_should_not_cause_relations = [
113-
ShouldNotCause(BaseTestCase("X1", "X2"), [], dag),
114-
ShouldNotCause(BaseTestCase("X1", "X3"), [], dag),
115-
ShouldNotCause(BaseTestCase("X1", "M"), ["Z"], dag),
116-
ShouldNotCause(BaseTestCase("X1", "Y"), ["Z"], dag),
117-
ShouldNotCause(BaseTestCase("X2", "X3"), [], dag),
118-
ShouldNotCause(BaseTestCase("X2", "M"), ["Z"], dag),
119-
ShouldNotCause(BaseTestCase("X2", "Y"), ["Z"], dag),
120-
ShouldNotCause(BaseTestCase("X3", "Y"), ["M", "Z"], dag),
121-
ShouldNotCause(BaseTestCase("Z", "X3"), [], dag),
113+
ShouldNotCause(BaseTestCase("X1", "X2"), []),
114+
ShouldNotCause(BaseTestCase("X1", "X3"), []),
115+
ShouldNotCause(BaseTestCase("X1", "M"), ["Z"]),
116+
ShouldNotCause(BaseTestCase("X1", "Y"), ["Z"]),
117+
ShouldNotCause(BaseTestCase("X2", "X3"), []),
118+
ShouldNotCause(BaseTestCase("X2", "M"), ["Z"]),
119+
ShouldNotCause(BaseTestCase("X2", "Y"), ["Z"]),
120+
ShouldNotCause(BaseTestCase("X3", "Y"), ["M", "Z"]),
121+
ShouldNotCause(BaseTestCase("Z", "X3"), []),
122122
]
123123

124124
extra_snc_relations = [
@@ -140,12 +140,12 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self):
140140

141141
# Check all ShouldCause relations are present and no extra
142142
expected_should_cause_relations = [
143-
ShouldCause(BaseTestCase("X1", "Z"), [], dag),
144-
ShouldCause(BaseTestCase("Z", "M"), [], dag),
145-
ShouldCause(BaseTestCase("M", "Y"), ["Z"], dag),
146-
ShouldCause(BaseTestCase("Z", "Y"), ["M"], dag),
147-
ShouldCause(BaseTestCase("X2", "Z"), [], dag),
148-
ShouldCause(BaseTestCase("X3", "M"), [], dag),
143+
ShouldCause(BaseTestCase("X1", "Z"), []),
144+
ShouldCause(BaseTestCase("Z", "M"), []),
145+
ShouldCause(BaseTestCase("M", "Y"), ["Z"]),
146+
ShouldCause(BaseTestCase("Z", "Y"), ["M"]),
147+
ShouldCause(BaseTestCase("X2", "Z"), []),
148+
ShouldCause(BaseTestCase("X3", "M"), []),
149149
]
150150

151151
extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations]
@@ -156,15 +156,15 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self):
156156

157157
# Check all ShouldNotCause relations are present and no extra
158158
expected_should_not_cause_relations = [
159-
ShouldNotCause(BaseTestCase("X1", "X2"), [], dag),
160-
ShouldNotCause(BaseTestCase("X1", "X3"), [], dag),
161-
ShouldNotCause(BaseTestCase("X1", "M"), ["Z"], dag),
162-
ShouldNotCause(BaseTestCase("X1", "Y"), ["Z"], dag),
163-
ShouldNotCause(BaseTestCase("X2", "X3"), [], dag),
164-
ShouldNotCause(BaseTestCase("X2", "M"), ["Z"], dag),
165-
ShouldNotCause(BaseTestCase("X2", "Y"), ["Z"], dag),
166-
ShouldNotCause(BaseTestCase("X3", "Y"), ["M", "Z"], dag),
167-
ShouldNotCause(BaseTestCase("Z", "X3"), [], dag),
159+
ShouldNotCause(BaseTestCase("X1", "X2"), []),
160+
ShouldNotCause(BaseTestCase("X1", "X3"), []),
161+
ShouldNotCause(BaseTestCase("X1", "M"), ["Z"]),
162+
ShouldNotCause(BaseTestCase("X1", "Y"), ["Z"]),
163+
ShouldNotCause(BaseTestCase("X2", "X3"), []),
164+
ShouldNotCause(BaseTestCase("X2", "M"), ["Z"]),
165+
ShouldNotCause(BaseTestCase("X2", "Y"), ["Z"]),
166+
ShouldNotCause(BaseTestCase("X3", "Y"), ["M", "Z"]),
167+
ShouldNotCause(BaseTestCase("Z", "X3"), []),
168168
]
169169

170170
extra_snc_relations = [
@@ -188,7 +188,7 @@ def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self):
188188
self.assertEqual(
189189
should_cause_relations,
190190
[
191-
ShouldCause(BaseTestCase("a", "b"), [], dag),
191+
ShouldCause(BaseTestCase("a", "b"), []),
192192
],
193193
)
194194
self.assertEqual(
@@ -201,47 +201,47 @@ def test_generate_metamorphic_relation_(self):
201201
[metamorphic_relation] = generate_metamorphic_relation(("X1", "Z"), dag)
202202
self.assertEqual(
203203
metamorphic_relation,
204-
ShouldCause(BaseTestCase("X1", "Z"), [], dag),
204+
ShouldCause(BaseTestCase("X1", "Z"), []),
205205
)
206206

207207
def test_equivalent_metamorphic_relations(self):
208208
dag = CausalDAG(self.dag_dot_path)
209-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"], dag)
210-
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"], dag)
209+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
210+
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
211211
self.assertEqual(sc_mr_a == sc_mr_b, True)
212212

213213
def test_equivalent_metamorphic_relations_empty_adjustment_set(self):
214214
dag = CausalDAG(self.dag_dot_path)
215-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [], dag)
216-
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [], dag)
215+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [])
216+
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [])
217217
self.assertEqual(sc_mr_a == sc_mr_b, True)
218218

219219
def test_equivalent_metamorphic_relations_different_order_adjustment_set(self):
220220
dag = CausalDAG(self.dag_dot_path)
221-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"], dag)
222-
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), ["C", "A", "B"], dag)
221+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
222+
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), ["C", "A", "B"])
223223
self.assertEqual(sc_mr_a == sc_mr_b, True)
224224

225225
def test_different_metamorphic_relations_empty_adjustment_set_different_outcome(self):
226226
dag = CausalDAG(self.dag_dot_path)
227-
sc_mr_a = ShouldCause(BaseTestCase("X", "Z"), [], dag)
228-
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [], dag)
227+
sc_mr_a = ShouldCause(BaseTestCase("X", "Z"), [])
228+
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [])
229229
self.assertEqual(sc_mr_a == sc_mr_b, False)
230230

231231
def test_different_metamorphic_relations_empty_adjustment_set_different_treatment(self):
232232
dag = CausalDAG(self.dag_dot_path)
233-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [], dag)
234-
sc_mr_b = ShouldCause(BaseTestCase("Z", "Y"), [], dag)
233+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [])
234+
sc_mr_b = ShouldCause(BaseTestCase("Z", "Y"), [])
235235
self.assertEqual(sc_mr_a == sc_mr_b, False)
236236

237237
def test_different_metamorphic_relations_empty_adjustment_set_adjustment_set(self):
238238
dag = CausalDAG(self.dag_dot_path)
239-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A"], dag)
240-
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [], dag)
239+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), ["A"])
240+
sc_mr_b = ShouldCause(BaseTestCase("X", "Y"), [])
241241
self.assertEqual(sc_mr_a == sc_mr_b, False)
242242

243243
def test_different_metamorphic_relations_different_type(self):
244244
dag = CausalDAG(self.dag_dot_path)
245-
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [], dag)
246-
sc_mr_b = ShouldNotCause(BaseTestCase("X", "Y"), [], dag)
245+
sc_mr_a = ShouldCause(BaseTestCase("X", "Y"), [])
246+
sc_mr_b = ShouldNotCause(BaseTestCase("X", "Y"), [])
247247
self.assertEqual(sc_mr_a == sc_mr_b, False)

0 commit comments

Comments
 (0)