13
13
import networkx as nx
14
14
import pandas as pd
15
15
import numpy as np
16
+ from multiprocessing import Pool
16
17
17
18
from causal_testing .specification .causal_specification import CausalDAG , Node
18
19
from causal_testing .data_collection .data_collector import ExperimentalDataCollector
@@ -214,46 +215,89 @@ def __str__(self):
214
215
)
215
216
216
217
217
- def generate_metamorphic_relations (dag : CausalDAG ) -> list [MetamorphicRelation ]:
218
def generate_metamorphic_relation(
    node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> list[MetamorphicRelation]:
    """Construct the metamorphic relation implied by the Causal DAG for a single pair of nodes.

    The result is always a list so that callers can flatten the results of many node pairs: it contains exactly one
    relation when an adjustment set is available, and is empty when no relation can be constructed (e.g. because
    every valid adjustment set contains a node to ignore).

    :param node_pair: The pair of nodes to consider.
    :param dag: Causal DAG from which the metamorphic relation will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. Defaults to the empty set.

    :return: A list containing at most one ShouldCause or ShouldNotCause metamorphic relation.
    """
    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    (u, v) = node_pair
    edges = dag.graph.edges

    # Decide the kind of relation and its direction first, so that the adjustment-set lookup and the construction of
    # the relation happen in exactly one place.
    if (u, v) in edges:
        # Directly connected U --> V: the edge itself should be observable.
        relation_type, cause, effect = ShouldCause, u, v
    elif (v, u) in edges:
        # Directly connected V --> U.
        relation_type, cause, effect = ShouldCause, v, u
    elif u in nx.ancestors(dag.graph, v):
        # Case 1: U --> ... --> V (no direct edge).
        relation_type, cause, effect = ShouldNotCause, u, v
    elif v in nx.ancestors(dag.graph, u):
        # Case 2: V --> ... --> U (no direct edge).
        relation_type, cause, effect = ShouldNotCause, v, u
    else:
        # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
        # Only make one MR since V _||_ U == U _||_ V
        relation_type, cause, effect = ShouldNotCause, u, v

    adj_sets = dag.direct_effect_adjustment_sets([cause], [effect], nodes_to_ignore=nodes_to_ignore)
    if not adj_sets:
        # No usable adjustment set, so no relation can be built for this pair.
        return []

    # Only the first adjustment set returned by the DAG is used.
    return [relation_type(cause, effect, list(adj_sets[0]), dag)]
268
+
269
+
270
def generate_metamorphic_relations(
    dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0
) -> list[MetamorphicRelation]:
    """Construct a list of metamorphic relations implied by the Causal DAG.

    This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
    relation for every (minimal) conditional independence relation implied by the structure of the DAG.

    :param dag: Causal DAG from which the metamorphic relations will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. Defaults to the empty set.
    :param threads: Number of worker processes to use (if generating in parallel). The value is passed to
                    ``multiprocessing.Pool``; a falsy value (the default) generates the relations serially in the
                    calling process.

    :return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
    """
    # Use None as the default rather than a shared mutable literal (``{}`` is also a dict, not a set).
    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    # Every unordered pair of nodes that are not being ignored is considered exactly once.
    node_pairs = combinations((node for node in dag.graph.nodes if node not in nodes_to_ignore), 2)

    if not threads:
        relations_per_pair = [
            generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) for node_pair in node_pairs
        ]
    else:
        with Pool(threads) as pool:
            # The result must be kept: each entry is the (possibly empty) list of relations for one node pair.
            # NOTE(review): the DAG is sent to the worker processes, so it presumably needs to be picklable - confirm.
            relations_per_pair = pool.starmap(
                generate_metamorphic_relation,
                ((node_pair, dag, nodes_to_ignore) for node_pair in node_pairs),
            )

    # Flatten the per-pair lists into a single list of relations.
    return [relation for relations in relations_per_pair for relation in relations]
257
301
258
302
259
303
if __name__ == "__main__" : # pragma: no cover
@@ -273,10 +317,18 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273
317
help = "Specify path where tests should be saved, normally a .json file." ,
274
318
required = True ,
275
319
)
320
+ parser .add_argument ("-i" , "--ignore-cycles" , action = "store_true" )
276
321
args = parser .parse_args ()
277
322
278
- causal_dag = CausalDAG (args .dag_path )
279
- relations = generate_metamorphic_relations (causal_dag )
323
+ causal_dag = CausalDAG (args .dag_path , ignore_cycles = args .ignore_cycles )
324
+
325
+ if not causal_dag .is_acyclic () and args .ignore_cycles :
326
+ logger .warning (
327
+ "Ignoring cycles by removing causal tests that reference any node within a cycle. "
328
+ "Your causal test suite WILL NOT BE COMPLETE!"
329
+ )
330
+ relations = generate_metamorphic_relations (causal_dag , nodes_to_ignore = set (causal_dag .cycle_nodes ()), threads = 20 )
331
+
280
332
tests = [
281
333
relation .to_json_stub (skip = False )
282
334
for relation in relations
0 commit comments