1
1
import unittest
2
2
import os
3
+
3
4
import pandas as pd
4
5
from itertools import combinations
5
6
6
- from tests .test_helpers import create_temp_dir_if_non_existent , remove_temp_dir_if_existent
7
+ from tests .test_helpers import create_temp_dir_if_non_existent
7
8
from causal_testing .specification .causal_dag import CausalDAG
8
9
from causal_testing .specification .causal_specification import Scenario
9
- from causal_testing .specification .metamorphic_relation import ShouldCause , ShouldNotCause
10
+ from causal_testing .specification .metamorphic_relation import (ShouldCause , ShouldNotCause ,
11
+ generate_metamorphic_relations )
10
12
from causal_testing .data_collection .data_collector import ExperimentalDataCollector
11
13
from causal_testing .specification .variable import Input , Output
12
14
@@ -74,7 +76,8 @@ def test_should_cause_metamorphic_relations_correct_spec(self):
74
76
causal_dag = CausalDAG (self .dag_dot_path )
75
77
for edge in causal_dag .graph .edges :
76
78
(u , v ) = edge
77
- should_cause_MR = ShouldCause (u , v , None , causal_dag )
79
+ adj_set = list (causal_dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
80
+ should_cause_MR = ShouldCause (u , v , adj_set , causal_dag )
78
81
should_cause_MR .generate_follow_up (10 , - 10.0 , 10.0 , 1 )
79
82
test_results = should_cause_MR .execute_tests (self .data_collector )
80
83
should_cause_MR .test_oracle (test_results )
@@ -88,9 +91,10 @@ def test_should_not_cause_metamorphic_relations_correct_spec(self):
88
91
if ((u , v ) not in causal_dag .graph .edges ) and ((v , u ) not in causal_dag .graph .edges ):
89
92
# Check both directions if there is no causality
90
93
# This can be done more efficiently by ignoring impossible directions (output --> input)
91
- adj_set = list (causal_dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
92
- u_should_not_cause_v_MR = ShouldNotCause (u , v , adj_set , causal_dag )
93
- v_should_not_cause_u_MR = ShouldNotCause (v , u , adj_set , causal_dag )
94
+ adj_set_u_to_v = list (causal_dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
95
+ u_should_not_cause_v_MR = ShouldNotCause (u , v , adj_set_u_to_v , causal_dag )
96
+ adj_set_v_to_u = list (causal_dag .direct_effect_adjustment_sets ([v ], [u ])[0 ])
97
+ v_should_not_cause_u_MR = ShouldNotCause (v , u , adj_set_v_to_u , causal_dag )
94
98
u_should_not_cause_v_MR .generate_follow_up (10 , - 100 , 100 )
95
99
v_should_not_cause_u_MR .generate_follow_up (10 , - 100 , 100 )
96
100
u_should_not_cause_v_test_results = u_should_not_cause_v_MR .execute_tests (self .data_collector )
@@ -115,4 +119,41 @@ def test_should_cause_metamorphic_relation_missing_relationship(self):
115
119
self .assertRaises (AssertionError , X1_should_cause_Z_MR .test_oracle , X1_should_cause_Z_test_results )
116
120
self .assertRaises (AssertionError , X2_should_cause_Z_MR .test_oracle , X2_should_cause_Z_test_results )
117
121
118
-
122
+ def test_all_metamorphic_relations_implied_by_dag (self ):
123
+ dag = CausalDAG (self .dag_dot_path )
124
+ metamorphic_relations = generate_metamorphic_relations (dag )
125
+ should_cause_relations = [mr for mr in metamorphic_relations if isinstance (mr , ShouldCause )]
126
+ should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance (mr , ShouldNotCause )]
127
+
128
+ # Check all ShouldCause relations are present and no extra
129
+ expected_should_cause_relations = [ShouldCause ('X1' , 'Z' , [], dag ),
130
+ ShouldCause ('Z' , 'M' , [], dag ),
131
+ ShouldCause ('M' , 'Y' , [], dag ),
132
+ ShouldCause ('X2' , 'Z' , [], dag ),
133
+ ShouldCause ('X3' , 'M' , [], dag )]
134
+
135
+ extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations ]
136
+ missing_sc_relations = [escr for escr in expected_should_cause_relations if escr not in should_cause_relations ]
137
+
138
+ self .assertEqual (extra_sc_relations , [])
139
+ self .assertEqual (missing_sc_relations , [])
140
+
141
+ # Check all ShouldNotCause relations are present and no extra
142
+ expected_should_not_cause_relations = [ShouldNotCause ('X1' , 'X2' , [], dag ),
143
+ ShouldNotCause ('X1' , 'X3' , [], dag ),
144
+ ShouldNotCause ('X1' , 'M' , ['Z' ], dag ),
145
+ ShouldNotCause ('X1' , 'Y' , ['M' ], dag ),
146
+ ShouldNotCause ('X2' , 'X3' , [], dag ),
147
+ ShouldNotCause ('X2' , 'M' , ['Z' ], dag ),
148
+ ShouldNotCause ('X2' , 'Y' , ['M' ], dag ),
149
+ ShouldNotCause ('X3' , 'Y' , ['M' ], dag ),
150
+ ShouldNotCause ('Z' , 'Y' , ['M' ], dag ),
151
+ ShouldNotCause ('Z' , 'X3' , [], dag )]
152
+
153
+ extra_snc_relations = [sncr for sncr in should_not_cause_relations if sncr
154
+ not in expected_should_not_cause_relations ]
155
+ missing_snc_relations = [esncr for esncr in expected_should_not_cause_relations if esncr
156
+ not in should_not_cause_relations ]
157
+
158
+ self .assertEqual (extra_snc_relations , [])
159
+ self .assertEqual (missing_snc_relations , [])
0 commit comments