Skip to content

Commit df3ec81

Browse files
author
AndrewC19
committed
Fixed flaky metamorphic relations tests.
Re-defined equality for the metamorphic relation class and this ensures the order of the adjustment set doesn't matter.
1 parent 393f6d9 commit df3ec81

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ def test_oracle(self, test_results):
107107
This method must raise an assertion, not return a bool."""
108108
...
109109

110+
def __eq__(self, other):
111+
same_type = self.__class__ == other.__class__
112+
same_treatment = self.treatment_var == other.treatment_var
113+
same_output = self.output_var == other.output_var
114+
same_adjustment_set = set(self.adjustment_vars) == set(other.adjustment_vars)
115+
return same_type and same_treatment and same_output and same_adjustment_set
116+
110117

111118
class ShouldCause(MetamorphicRelation):
112119
"""Class representing a should cause metamorphic relation."""

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,45 @@ def test_all_metamorphic_relations_implied_by_dag(self):
164164

165165
self.assertEqual(extra_snc_relations, [])
166166
self.assertEqual(missing_snc_relations, [])
167+
168+
def test_equivalent_metamorphic_relations(self):
169+
dag = CausalDAG(self.dag_dot_path)
170+
sc_mr_a = ShouldCause("X", "Y", ["A", "B", "C"], dag)
171+
sc_mr_b = ShouldCause("X", "Y", ["A", "B", "C"], dag)
172+
self.assertEqual(sc_mr_a == sc_mr_b, True)
173+
174+
def test_equivalent_metamorphic_relations_empty_adjustment_set(self):
175+
dag = CausalDAG(self.dag_dot_path)
176+
sc_mr_a = ShouldCause("X", "Y", [], dag)
177+
sc_mr_b = ShouldCause("X", "Y", [], dag)
178+
self.assertEqual(sc_mr_a == sc_mr_b, True)
179+
180+
def test_equivalent_metamorphic_relations_different_order_adjustment_set(self):
181+
dag = CausalDAG(self.dag_dot_path)
182+
sc_mr_a = ShouldCause("X", "Y", ["A", "B", "C"], dag)
183+
sc_mr_b = ShouldCause("X", "Y", ["C", "A", "B"], dag)
184+
self.assertEqual(sc_mr_a == sc_mr_b, True)
185+
186+
def test_different_metamorphic_relations_empty_adjustment_set_different_outcome(self):
187+
dag = CausalDAG(self.dag_dot_path)
188+
sc_mr_a = ShouldCause("X", "Z", [], dag)
189+
sc_mr_b = ShouldCause("X", "Y", [], dag)
190+
self.assertEqual(sc_mr_a == sc_mr_b, False)
191+
192+
def test_different_metamorphic_relations_empty_adjustment_set_different_treatment(self):
193+
dag = CausalDAG(self.dag_dot_path)
194+
sc_mr_a = ShouldCause("X", "Y", [], dag)
195+
sc_mr_b = ShouldCause("Z", "Y", [], dag)
196+
self.assertEqual(sc_mr_a == sc_mr_b, False)
197+
198+
def test_different_metamorphic_relations_empty_adjustment_set_adjustment_set(self):
199+
dag = CausalDAG(self.dag_dot_path)
200+
sc_mr_a = ShouldCause("X", "Y", ["A"], dag)
201+
sc_mr_b = ShouldCause("X", "Y", [], dag)
202+
self.assertEqual(sc_mr_a == sc_mr_b, False)
203+
204+
def test_different_metamorphic_relations_different_type(self):
205+
dag = CausalDAG(self.dag_dot_path)
206+
sc_mr_a = ShouldCause("X", "Y", [], dag)
207+
sc_mr_b = ShouldNotCause("X", "Y", [], dag)
208+
self.assertEqual(sc_mr_a == sc_mr_b, False)

0 commit comments

Comments
 (0)