Skip to content

Commit ea8b3ab

Browse files
author
AndrewC19
committed
Implemented and tested MR generation from dag method
1 parent 2131be7 commit ea8b3ab

File tree

2 files changed

+92
-9
lines changed

2 files changed

+92
-9
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from itertools import combinations
55
import numpy as np
66
import pandas as pd
7+
import networkx as nx
78

89
from causal_testing.specification.causal_specification import CausalDAG, Node
910
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1011

12+
1113
@dataclass(order=True)
1214
class MetamorphicRelation:
1315
"""Class representing a metamorphic relation."""
@@ -89,12 +91,12 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
8991
data_collector.control_input_configuration = control_input_config
9092
data_collector.treatment_input_configuration = treatment_input_config
9193
metamorphic_test_results_df = data_collector.collect_data()
94+
9295
# Apply assertion to control and treatment outputs
9396
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
9497
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
98+
9599
if not self.assertion(control_output, treatment_output):
96-
print(metamorphic_test.output)
97-
print(control_output, treatment_output)
98100
test_results["fail"].append(metamorphic_test)
99101
else:
100102
test_results["pass"].append(metamorphic_test)
@@ -166,3 +168,43 @@ def __str__(self):
166168
f"Other inputs: {self.other_inputs}\n" \
167169
f"Output: {self.output}" \
168170
f"Metamorphic Relation: {self.relation}"
171+
172+
173+
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
174+
"""Construct a list of metamorphic relations implied by the Causal DAG.
175+
176+
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
177+
relation for every conditional independence relation.
178+
"""
179+
metamorphic_relations = []
180+
for node_pair in combinations(dag.graph.nodes, 2):
181+
(u, v) = node_pair
182+
183+
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
184+
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
185+
186+
# Case 1: U --> ... --> V
187+
if u in nx.ancestors(dag.graph, v):
188+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
189+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
190+
191+
# Case 2: V --> ... --> U
192+
elif v in nx.ancestors(dag.graph, u):
193+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
194+
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
195+
196+
# Case 3: V _||_ U (neither is a predecessor)
197+
# Only make one MR since V _||_ U == U _||_ V
198+
else:
199+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
200+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
201+
202+
# Create a ShouldCause relation for each edge (u, v) or (v, u)
203+
elif (u, v) in dag.graph.edges:
204+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
205+
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
206+
else:
207+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
208+
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
209+
210+
return metamorphic_relations

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import unittest
22
import os
3+
34
import pandas as pd
45
from itertools import combinations
56

6-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
7+
from tests.test_helpers import create_temp_dir_if_non_existent
78
from causal_testing.specification.causal_dag import CausalDAG
89
from causal_testing.specification.causal_specification import Scenario
9-
from causal_testing.specification.metamorphic_relation import ShouldCause, ShouldNotCause
10+
from causal_testing.specification.metamorphic_relation import (ShouldCause, ShouldNotCause,
11+
generate_metamorphic_relations)
1012
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1113
from causal_testing.specification.variable import Input, Output
1214

@@ -74,7 +76,8 @@ def test_should_cause_metamorphic_relations_correct_spec(self):
7476
causal_dag = CausalDAG(self.dag_dot_path)
7577
for edge in causal_dag.graph.edges:
7678
(u, v) = edge
77-
should_cause_MR = ShouldCause(u, v, None, causal_dag)
79+
adj_set = list(causal_dag.direct_effect_adjustment_sets([u], [v])[0])
80+
should_cause_MR = ShouldCause(u, v, adj_set, causal_dag)
7881
should_cause_MR.generate_follow_up(10, -10.0, 10.0, 1)
7982
test_results = should_cause_MR.execute_tests(self.data_collector)
8083
should_cause_MR.test_oracle(test_results)
@@ -88,9 +91,10 @@ def test_should_not_cause_metamorphic_relations_correct_spec(self):
8891
if ((u, v) not in causal_dag.graph.edges) and ((v, u) not in causal_dag.graph.edges):
8992
# Check both directions if there is no causality
9093
# This can be done more efficiently by ignoring impossible directions (output --> input)
91-
adj_set = list(causal_dag.direct_effect_adjustment_sets([u], [v])[0])
92-
u_should_not_cause_v_MR = ShouldNotCause(u, v, adj_set, causal_dag)
93-
v_should_not_cause_u_MR = ShouldNotCause(v, u, adj_set, causal_dag)
94+
adj_set_u_to_v = list(causal_dag.direct_effect_adjustment_sets([u], [v])[0])
95+
u_should_not_cause_v_MR = ShouldNotCause(u, v, adj_set_u_to_v, causal_dag)
96+
adj_set_v_to_u = list(causal_dag.direct_effect_adjustment_sets([v], [u])[0])
97+
v_should_not_cause_u_MR = ShouldNotCause(v, u, adj_set_v_to_u, causal_dag)
9498
u_should_not_cause_v_MR.generate_follow_up(10, -100, 100)
9599
v_should_not_cause_u_MR.generate_follow_up(10, -100, 100)
96100
u_should_not_cause_v_test_results = u_should_not_cause_v_MR.execute_tests(self.data_collector)
@@ -115,4 +119,41 @@ def test_should_cause_metamorphic_relation_missing_relationship(self):
115119
self.assertRaises(AssertionError, X1_should_cause_Z_MR.test_oracle, X1_should_cause_Z_test_results)
116120
self.assertRaises(AssertionError, X2_should_cause_Z_MR.test_oracle, X2_should_cause_Z_test_results)
117121

118-
122+
def test_all_metamorphic_relations_implied_by_dag(self):
123+
dag = CausalDAG(self.dag_dot_path)
124+
metamorphic_relations = generate_metamorphic_relations(dag)
125+
should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)]
126+
should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)]
127+
128+
# Check all ShouldCause relations are present and no extra
129+
expected_should_cause_relations = [ShouldCause('X1', 'Z', [], dag),
130+
ShouldCause('Z', 'M', [], dag),
131+
ShouldCause('M', 'Y', [], dag),
132+
ShouldCause('X2', 'Z', [], dag),
133+
ShouldCause('X3', 'M', [], dag)]
134+
135+
extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations]
136+
missing_sc_relations = [escr for escr in expected_should_cause_relations if escr not in should_cause_relations]
137+
138+
self.assertEqual(extra_sc_relations, [])
139+
self.assertEqual(missing_sc_relations, [])
140+
141+
# Check all ShouldNotCause relations are present and no extra
142+
expected_should_not_cause_relations = [ShouldNotCause('X1', 'X2', [], dag),
143+
ShouldNotCause('X1', 'X3', [], dag),
144+
ShouldNotCause('X1', 'M', ['Z'], dag),
145+
ShouldNotCause('X1', 'Y', ['M'], dag),
146+
ShouldNotCause('X2', 'X3', [], dag),
147+
ShouldNotCause('X2', 'M', ['Z'], dag),
148+
ShouldNotCause('X2', 'Y', ['M'], dag),
149+
ShouldNotCause('X3', 'Y', ['M'], dag),
150+
ShouldNotCause('Z', 'Y', ['M'], dag),
151+
ShouldNotCause('Z', 'X3', [], dag)]
152+
153+
extra_snc_relations = [sncr for sncr in should_not_cause_relations if sncr
154+
not in expected_should_not_cause_relations]
155+
missing_snc_relations = [esncr for esncr in expected_should_not_cause_relations if esncr
156+
not in should_not_cause_relations]
157+
158+
self.assertEqual(extra_snc_relations, [])
159+
self.assertEqual(missing_snc_relations, [])

0 commit comments

Comments
 (0)