13
13
from multiprocessing import Pool
14
14
15
15
import networkx as nx
16
- import pandas as pd
17
- import numpy as np
18
16
19
17
from causal_testing .specification .causal_specification import CausalDAG , Node
20
- from causal_testing .data_collection . data_collector import ExperimentalDataCollector
18
+ from causal_testing .testing . base_test_case import BaseTestCase
21
19
22
20
logger = logging .getLogger (__name__ )
23
21
26
24
class MetamorphicRelation :
27
25
"""Class representing a metamorphic relation."""
28
26
29
- treatment_var : Node
30
- output_var : Node
27
+ base_test_case : BaseTestCase
31
28
adjustment_vars : Iterable [Node ]
32
29
dag : CausalDAG
33
30
tests : Iterable = None
34
31
35
- def generate_follow_up (self , n_tests : int , min_val : float , max_val : float , seed : int = 0 ):
36
- """Generate numerical follow-up input configurations."""
37
- np .random .seed (seed )
38
-
39
- # Get set of variables to change, excluding the treatment itself
40
- variables_to_change = {node for node in self .dag .graph .nodes if self .dag .graph .in_degree (node ) == 0 }
41
- if self .adjustment_vars :
42
- variables_to_change |= set (self .adjustment_vars )
43
- if self .treatment_var in variables_to_change :
44
- variables_to_change .remove (self .treatment_var )
45
-
46
- # Assign random numerical values to the variables to change
47
- test_inputs = pd .DataFrame (
48
- np .random .randint (min_val , max_val , size = (n_tests , len (variables_to_change ))),
49
- columns = sorted (variables_to_change ),
50
- )
51
-
52
- # Enumerate the possible source, follow-up pairs for the treatment
53
- candidate_source_follow_up_pairs = np .array (list (combinations (range (int (min_val ), int (max_val + 1 )), 2 )))
54
-
55
- # Sample without replacement from the possible source, follow-up pairs
56
- sampled_source_follow_up_indices = np .random .choice (
57
- candidate_source_follow_up_pairs .shape [0 ], n_tests , replace = False
58
- )
59
-
60
- follow_up_input = f"{ self .treatment_var } '"
61
- source_follow_up_test_inputs = pd .DataFrame (
62
- candidate_source_follow_up_pairs [sampled_source_follow_up_indices ],
63
- columns = sorted ([self .treatment_var ] + [follow_up_input ]),
64
- )
65
- self .tests = [
66
- MetamorphicTest (
67
- source_inputs ,
68
- follow_up_inputs ,
69
- other_inputs ,
70
- self .output_var ,
71
- str (self ),
72
- )
73
- for source_inputs , follow_up_inputs , other_inputs in zip (
74
- source_follow_up_test_inputs [[self .treatment_var ]].to_dict (orient = "records" ),
75
- source_follow_up_test_inputs [[follow_up_input ]]
76
- .rename (columns = {follow_up_input : self .treatment_var })
77
- .to_dict (orient = "records" ),
78
- (
79
- test_inputs .to_dict (orient = "records" )
80
- if not test_inputs .empty
81
- else [{}] * len (source_follow_up_test_inputs )
82
- ),
83
- )
84
- ]
85
-
86
- def execute_tests (self , data_collector : ExperimentalDataCollector ):
87
- """Execute the generated list of metamorphic tests, returning a dictionary of tests that pass and fail.
88
-
89
- :param data_collector: An experimental data collector for the system-under-test.
90
- """
91
- test_results = {"pass" : [], "fail" : []}
92
- for metamorphic_test in self .tests :
93
- # Update the control and treatment configuration to take generated values for source and follow-up tests
94
- control_input_config = metamorphic_test .source_inputs | metamorphic_test .other_inputs
95
- treatment_input_config = metamorphic_test .follow_up_inputs | metamorphic_test .other_inputs
96
- data_collector .control_input_configuration = control_input_config
97
- data_collector .treatment_input_configuration = treatment_input_config
98
- metamorphic_test_results_df = data_collector .collect_data ()
99
-
100
- # Apply assertion to control and treatment outputs
101
- control_output = metamorphic_test_results_df .loc ["control_0" ][metamorphic_test .output ]
102
- treatment_output = metamorphic_test_results_df .loc ["treatment_0" ][metamorphic_test .output ]
103
-
104
- if not self .assertion (control_output , treatment_output ):
105
- test_results ["fail" ].append (metamorphic_test )
106
- else :
107
- test_results ["pass" ].append (metamorphic_test )
108
- return test_results
109
-
110
- @abstractmethod
111
- def assertion (self , source_output , follow_up_output ):
112
- """An assertion that should be applied to an individual metamorphic test run."""
113
-
114
32
@abstractmethod
115
33
def to_json_stub (self , skip = True ) -> dict :
116
34
"""Convert to a JSON frontend stub string for user customisation"""
@@ -123,10 +41,11 @@ def test_oracle(self, test_results):
123
41
124
42
def __eq__(self, other):
    """Return True when ``other`` describes the same metamorphic relation.

    Two relations are equal when they are instances of the same class, their
    base test cases agree on treatment variable, outcome variable and effect,
    and their adjustment variables form the same set (order is irrelevant).

    :param other: The object to compare against; may be of any type.
    :return: True if both relations are equivalent, False otherwise.
    """
    # Bail out before touching relation-specific attributes: comparing with
    # None, a string, or any unrelated object must yield False rather than
    # raise AttributeError.
    if self.__class__ != other.__class__:
        return False

    own_case = self.base_test_case
    other_case = other.base_test_case
    return (
        own_case.treatment_variable == other_case.treatment_variable
        and own_case.outcome_variable == other_case.outcome_variable
        and own_case.effect == other_case.effect
        # Adjustment sets are unordered, so compare them as sets.
        and set(self.adjustment_vars) == set(other.adjustment_vars)
    )
130
49
131
50
132
51
class ShouldCause (MetamorphicRelation ):
@@ -149,14 +68,14 @@ def to_json_stub(self, skip=True) -> dict:
149
68
"estimator" : "LinearRegressionEstimator" ,
150
69
"estimate_type" : "coefficient" ,
151
70
"effect" : "direct" ,
152
- "mutations" : [self .treatment_var ],
153
- "expected_effect" : {self .output_var : "SomeEffect" },
154
- "formula" : f"{ self .output_var } ~ { ' + ' .join ([self .treatment_var ] + self .adjustment_vars )} " ,
71
+ "mutations" : [self .base_test_case . treatment_variable ],
72
+ "expected_effect" : {self .base_test_case . outcome_variable : "SomeEffect" },
73
+ "formula" : f"{ self .base_test_case . outcome_variable } ~ { ' + ' .join ([self .base_test_case . treatment_variable ] + self .adjustment_vars )} " ,
155
74
"skip" : skip ,
156
75
}
157
76
158
77
def __str__(self):
    """Render the relation as ``treatment --> outcome``.

    When adjustment variables are present they are appended after a
    ``" | "`` separator.
    """
    case = self.base_test_case
    rendered = f"{case.treatment_variable} --> {case.outcome_variable}"
    if not self.adjustment_vars:
        return rendered
    return f"{rendered} | {self.adjustment_vars}"
@@ -182,40 +101,20 @@ def to_json_stub(self, skip=True) -> dict:
182
101
"estimator" : "LinearRegressionEstimator" ,
183
102
"estimate_type" : "coefficient" ,
184
103
"effect" : "direct" ,
185
- "mutations" : [self .treatment_var ],
186
- "expected_effect" : {self .output_var : "NoEffect" },
187
- "formula" : f"{ self .output_var } ~ { ' + ' .join ([self .treatment_var ] + self .adjustment_vars )} " ,
104
+ "mutations" : [self .base_test_case . treatment_variable ],
105
+ "expected_effect" : {self .base_test_case . outcome_variable : "NoEffect" },
106
+ "formula" : f"{ self .base_test_case . outcome_variable } ~ { ' + ' .join ([self .base_test_case . treatment_variable ] + self .adjustment_vars )} " ,
188
107
"alpha" : 0.05 ,
189
108
"skip" : skip ,
190
109
}
191
110
192
111
def __str__(self):
    """Render the relation as ``treatment _||_ outcome``.

    When adjustment variables are present they are appended after a
    ``" | "`` separator.
    """
    case = self.base_test_case
    rendered = f"{case.treatment_variable} _||_ {case.outcome_variable}"
    if not self.adjustment_vars:
        return rendered
    return f"{rendered} | {self.adjustment_vars}"
197
116
198
117
199
- @dataclass (order = True )
200
- class MetamorphicTest :
201
- """Class representing a metamorphic test case."""
202
-
203
- source_inputs : dict
204
- follow_up_inputs : dict
205
- other_inputs : dict
206
- output : str
207
- relation : str
208
-
209
- def __str__ (self ):
210
- return (
211
- f"Source inputs: { self .source_inputs } \n "
212
- f"Follow-up inputs: { self .follow_up_inputs } \n "
213
- f"Other inputs: { self .other_inputs } \n "
214
- f"Output: { self .output } "
215
- f"Metamorphic Relation: { self .relation } "
216
- )
217
-
218
-
219
118
def generate_metamorphic_relation (
220
119
node_pair : tuple [str , str ], dag : CausalDAG , nodes_to_ignore : set = None
221
120
) -> MetamorphicRelation :
@@ -241,30 +140,30 @@ def generate_metamorphic_relation(
241
140
if u in nx .ancestors (dag .graph , v ):
242
141
adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
243
142
if adj_sets :
244
- metamorphic_relations .append (ShouldNotCause (u , v , list (adj_sets [0 ]), dag ))
143
+ metamorphic_relations .append (ShouldNotCause (BaseTestCase ( u , v ) , list (adj_sets [0 ]), dag ))
245
144
246
145
# Case 2: V --> ... --> U
247
146
elif v in nx .ancestors (dag .graph , u ):
248
147
adj_sets = dag .direct_effect_adjustment_sets ([v ], [u ], nodes_to_ignore = nodes_to_ignore )
249
148
if adj_sets :
250
- metamorphic_relations .append (ShouldNotCause (v , u , list (adj_sets [0 ]), dag ))
149
+ metamorphic_relations .append (ShouldNotCause (BaseTestCase ( v , u ) , list (adj_sets [0 ]), dag ))
251
150
252
151
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
253
152
# Only make one MR since V _||_ U == U _||_ V
254
153
else :
255
154
adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
256
155
if adj_sets :
257
- metamorphic_relations .append (ShouldNotCause (u , v , list (adj_sets [0 ]), dag ))
156
+ metamorphic_relations .append (ShouldNotCause (BaseTestCase ( u , v ) , list (adj_sets [0 ]), dag ))
258
157
259
158
# Create a ShouldCause relation for each edge (u, v) or (v, u)
260
159
elif (u , v ) in dag .graph .edges :
261
160
adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
262
161
if adj_sets :
263
- metamorphic_relations .append (ShouldCause (u , v , list (adj_sets [0 ]), dag ))
162
+ metamorphic_relations .append (ShouldCause (BaseTestCase ( u , v ) , list (adj_sets [0 ]), dag ))
264
163
else :
265
164
adj_sets = dag .direct_effect_adjustment_sets ([v ], [u ], nodes_to_ignore = nodes_to_ignore )
266
165
if adj_sets :
267
- metamorphic_relations .append (ShouldCause (v , u , list (adj_sets [0 ]), dag ))
166
+ metamorphic_relations .append (ShouldCause (BaseTestCase ( v , u ) , list (adj_sets [0 ]), dag ))
268
167
return metamorphic_relations
269
168
270
169
0 commit comments