Skip to content

Commit 9c67392

Browse files
author
AndrewC19
committed
ShouldNotCause MR works
1 parent 8bff2a9 commit 9c67392

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
8989
data_collector.control_input_configuration = control_input_config
9090
data_collector.treatment_input_configuration = treatment_input_config
9191
metamorphic_test_results_df = data_collector.collect_data()
92-
9392
# Apply assertion to control and treatment outputs
9493
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
9594
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
9695
if not self.assertion(control_output, treatment_output):
96+
print(metamorphic_test.output)
97+
print(control_output, treatment_output)
9798
test_results["fail"].append(metamorphic_test)
9899
else:
99100
test_results["pass"].append(metamorphic_test)
@@ -112,7 +113,6 @@ def test_oracle(self, test_results):
112113
...
113114

114115

115-
@dataclass(order=True)
116116
class ShouldCause(MetamorphicRelation):
117117
"""Class representing a should cause metamorphic relation."""
118118

@@ -132,6 +132,25 @@ def __str__(self):
132132
return formatted_str
133133

134134

135+
class ShouldNotCause(MetamorphicRelation):
136+
"""Class representing a should cause metamorphic relation."""
137+
138+
def assertion(self, source_output, follow_up_output):
139+
"""If there is a causal effect, the outputs should not be the same."""
140+
return source_output == follow_up_output
141+
142+
def test_oracle(self, test_results):
143+
"""A single passing test is sufficient to show presence of a causal effect."""
144+
assert len(test_results["fail"]) == 0,\
145+
f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
146+
147+
def __str__(self):
148+
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
149+
if self.adjustment_vars:
150+
formatted_str += f" | {self.adjustment_vars}"
151+
return formatted_str
152+
153+
135154
@dataclass(order=True)
136155
class MetamorphicTest:
137156
"""Class representing a metamorphic test case."""

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import unittest
22
import os
3-
43
import pandas as pd
4+
from itertools import combinations
55

66
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
77
from causal_testing.specification.causal_dag import CausalDAG
88
from causal_testing.specification.causal_specification import Scenario
9-
from causal_testing.specification.metamorphic_relation import ShouldCause
9+
from causal_testing.specification.metamorphic_relation import ShouldCause, ShouldNotCause
1010
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1111
from causal_testing.specification.variable import Input, Output
1212

@@ -18,7 +18,7 @@ def program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
1818
M = 3*Z + X3
1919
if Y is None:
2020
Y = M/2
21-
return {'Z': Z, 'M': M, 'Y': Y}
21+
return {'X1': X1, 'X2': X2, 'X3': X3, 'Z': Z, 'M': M, 'Y': Y}
2222

2323

2424
def buggy_program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
@@ -28,7 +28,7 @@ def buggy_program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
2828
M = 3*Z + X3
2929
if Y is None:
3030
Y = M/2
31-
return {'Z': Z, 'M': M, 'Y': Y}
31+
return {'X1': X1, 'X2': X2, 'X3': X3, 'Z': Z, 'M': M, 'Y': Y}
3232

3333

3434
class ProgramUnderTestEDC(ExperimentalDataCollector):
@@ -52,7 +52,7 @@ class TestMetamorphicRelation(unittest.TestCase):
5252
def setUp(self) -> None:
5353
temp_dir_path = create_temp_dir_if_non_existent()
5454
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
55-
dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X1 -> M; X2 -> Z; X3 -> M;}"""
55+
dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X2 -> Z; X3 -> M;}"""
5656
with open(self.dag_dot_path, "w") as f:
5757
f.write(dag_dot)
5858

@@ -69,7 +69,8 @@ def setUp(self) -> None:
6969
self.default_control_input_config,
7070
self.default_treatment_input_config)
7171

72-
def test_should_cause_metamorphic_relations_should_pass(self):
72+
def test_should_cause_metamorphic_relations_correct_spec(self):
73+
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program."""
7374
causal_dag = CausalDAG(self.dag_dot_path)
7475
for edge in causal_dag.graph.edges:
7576
(u, v) = edge
@@ -78,9 +79,30 @@ def test_should_cause_metamorphic_relations_should_pass(self):
7879
test_results = should_cause_MR.execute_tests(self.data_collector)
7980
should_cause_MR.test_oracle(test_results)
8081

82+
def test_should_not_cause_metamorphic_relations_correct_spec(self):
83+
"""Test if the ShouldNotCause MR passes all metamorphic tests where the DAG perfectly represents the program."""
84+
causal_dag = CausalDAG(self.dag_dot_path)
85+
for node_pair in combinations(causal_dag.graph.nodes, 2):
86+
(u, v) = node_pair
87+
# Get all pairs of nodes which don't form an edge
88+
if ((u, v) not in causal_dag.graph.edges) and ((v, u) not in causal_dag.graph.edges):
89+
# Check both directions if there is no causality
90+
# This can be done more efficiently by ignoring impossible directions (output --> input)
91+
adj_set = list(causal_dag.direct_effect_adjustment_sets([u], [v])[0])
92+
u_should_not_cause_v_MR = ShouldNotCause(u, v, adj_set, causal_dag)
93+
v_should_not_cause_u_MR = ShouldNotCause(v, u, adj_set, causal_dag)
94+
u_should_not_cause_v_MR.generate_follow_up(10, -100, 100)
95+
v_should_not_cause_u_MR.generate_follow_up(10, -100, 100)
96+
u_should_not_cause_v_test_results = u_should_not_cause_v_MR.execute_tests(self.data_collector)
97+
v_should_not_cause_u_test_results = v_should_not_cause_u_MR.execute_tests(self.data_collector)
98+
u_should_not_cause_v_MR.test_oracle(u_should_not_cause_v_test_results)
99+
v_should_not_cause_u_MR.test_oracle(v_should_not_cause_u_test_results)
100+
81101
def test_should_cause_metamorphic_relation_missing_relationship(self):
82102
"""Test whether the ShouldCause MR catches missing relationships in the DAG."""
83103
causal_dag = CausalDAG(self.dag_dot_path)
104+
105+
# Replace the data collector with one that runs a buggy program in which X1 and X2 do not affect Z
84106
self.data_collector = BuggyProgramUnderTestEDC(self.scenario,
85107
self.default_control_input_config,
86108
self.default_treatment_input_config)

0 commit comments

Comments
 (0)