10
10
import argparse
11
11
import logging
12
12
import json
13
+ from multiprocessing import Pool
14
+
13
15
import networkx as nx
14
16
import pandas as pd
15
17
import numpy as np
@@ -214,46 +216,96 @@ def __str__(self):
214
216
)
215
217
216
218
217
- def generate_metamorphic_relations (dag : CausalDAG ) -> list [MetamorphicRelation ]:
219
+ def generate_metamorphic_relation (
220
+ node_pair : tuple [str , str ], dag : CausalDAG , nodes_to_ignore : set = None
221
+ ) -> MetamorphicRelation :
222
+ """Construct a metamorphic relation for a given node pair implied by the Causal DAG, or None if no such relation can
223
+ be constructed (e.g. because every valid adjustment set contains a node to ignore).
224
+
225
+ :param node_pair: The pair of nodes to consider.
226
+ :param dag: Causal DAG from which the metamorphic relations will be generated.
227
+ :param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
228
+
229
+ :return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
230
+ """
231
+
232
+ if nodes_to_ignore is None :
233
+ nodes_to_ignore = set ()
234
+
235
+ (u , v ) = node_pair
236
+ metamorphic_relations = []
237
+
238
+ # Create a ShouldNotCause relation for each pair of nodes that are not directly connected
239
+ if ((u , v ) not in dag .graph .edges ) and ((v , u ) not in dag .graph .edges ):
240
+ # Case 1: U --> ... --> V
241
+ if u in nx .ancestors (dag .graph , v ):
242
+ adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
243
+ if adj_sets :
244
+ metamorphic_relations .append (ShouldNotCause (u , v , list (adj_sets [0 ]), dag ))
245
+
246
+ # Case 2: V --> ... --> U
247
+ elif v in nx .ancestors (dag .graph , u ):
248
+ adj_sets = dag .direct_effect_adjustment_sets ([v ], [u ], nodes_to_ignore = nodes_to_ignore )
249
+ if adj_sets :
250
+ metamorphic_relations .append (ShouldNotCause (v , u , list (adj_sets [0 ]), dag ))
251
+
252
+ # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
253
+ # Only make one MR since V _||_ U == U _||_ V
254
+ else :
255
+ adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
256
+ if adj_sets :
257
+ metamorphic_relations .append (ShouldNotCause (u , v , list (adj_sets [0 ]), dag ))
258
+
259
+ # Create a ShouldCause relation for each edge (u, v) or (v, u)
260
+ elif (u , v ) in dag .graph .edges :
261
+ adj_sets = dag .direct_effect_adjustment_sets ([u ], [v ], nodes_to_ignore = nodes_to_ignore )
262
+ if adj_sets :
263
+ metamorphic_relations .append (ShouldCause (u , v , list (adj_sets [0 ]), dag ))
264
+ else :
265
+ adj_sets = dag .direct_effect_adjustment_sets ([v ], [u ], nodes_to_ignore = nodes_to_ignore )
266
+ if adj_sets :
267
+ metamorphic_relations .append (ShouldCause (v , u , list (adj_sets [0 ]), dag ))
268
+ return metamorphic_relations
269
+
270
+
271
+ def generate_metamorphic_relations (
272
+ dag : CausalDAG , nodes_to_ignore : set = None , threads : int = 0 , nodes_to_test : set = None
273
+ ) -> list [MetamorphicRelation ]:
218
274
"""Construct a list of metamorphic relations implied by the Causal DAG.
219
275
220
276
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
221
277
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
222
278
223
- :param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
279
+ :param dag: Causal DAG from which the metamorphic relations will be generated.
280
+ :param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
281
+ :param threads: Number of threads to use (if generating in parallel).
282
+ :param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes).
283
+
224
284
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
225
285
"""
226
- metamorphic_relations = []
227
- for node_pair in combinations (dag .graph .nodes , 2 ):
228
- (u , v ) = node_pair
229
-
230
- # Create a ShouldNotCause relation for each pair of nodes that are not directly connected
231
- if ((u , v ) not in dag .graph .edges ) and ((v , u ) not in dag .graph .edges ):
232
- # Case 1: U --> ... --> V
233
- if u in nx .ancestors (dag .graph , v ):
234
- adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
235
- metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
236
-
237
- # Case 2: V --> ... --> U
238
- elif v in nx .ancestors (dag .graph , u ):
239
- adj_set = list (dag .direct_effect_adjustment_sets ([v ], [u ])[0 ])
240
- metamorphic_relations .append (ShouldNotCause (v , u , adj_set , dag ))
241
-
242
- # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
243
- # Only make one MR since V _||_ U == U _||_ V
244
- else :
245
- adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
246
- metamorphic_relations .append (ShouldNotCause (u , v , adj_set , dag ))
247
286
248
- # Create a ShouldCause relation for each edge (u, v) or (v, u)
249
- elif (u , v ) in dag .graph .edges :
250
- adj_set = list (dag .direct_effect_adjustment_sets ([u ], [v ])[0 ])
251
- metamorphic_relations .append (ShouldCause (u , v , adj_set , dag ))
252
- else :
253
- adj_set = list (dag .direct_effect_adjustment_sets ([v ], [u ])[0 ])
254
- metamorphic_relations .append (ShouldCause (v , u , adj_set , dag ))
287
+ if nodes_to_ignore is None :
288
+ nodes_to_ignore = {}
255
289
256
- return metamorphic_relations
290
+ if nodes_to_test is None :
291
+ nodes_to_test = dag .graph .nodes
292
+
293
+ if not threads :
294
+ metamorphic_relations = [
295
+ generate_metamorphic_relation (node_pair , dag , nodes_to_ignore )
296
+ for node_pair in combinations (filter (lambda node : node not in nodes_to_ignore , nodes_to_test ), 2 )
297
+ ]
298
+ else :
299
+ with Pool (threads ) as pool :
300
+ metamorphic_relations = pool .starmap (
301
+ generate_metamorphic_relation ,
302
+ map (
303
+ lambda node_pair : (node_pair , dag , nodes_to_ignore ),
304
+ combinations (filter (lambda node : node not in nodes_to_ignore , nodes_to_test ), 2 ),
305
+ ),
306
+ )
307
+
308
+ return [item for items in metamorphic_relations for item in items ]
257
309
258
310
259
311
if __name__ == "__main__" : # pragma: no cover
@@ -273,10 +325,32 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273
325
help = "Specify path where tests should be saved, normally a .json file." ,
274
326
required = True ,
275
327
)
328
+ parser .add_argument (
329
+ "--threads" , "-t" , type = int , help = "The number of parallel threads to use." , required = False , default = 0
330
+ )
331
+ parser .add_argument ("-i" , "--ignore-cycles" , action = "store_true" )
276
332
args = parser .parse_args ()
277
333
278
- causal_dag = CausalDAG (args .dag_path )
279
- relations = generate_metamorphic_relations (causal_dag )
334
+ causal_dag = CausalDAG (args .dag_path , ignore_cycles = args .ignore_cycles )
335
+
336
+ dag_nodes_to_test = set (
337
+ k for k , v in nx .get_node_attributes (causal_dag .graph , "test" , default = True ).items () if v == "True"
338
+ )
339
+
340
+ if not causal_dag .is_acyclic () and args .ignore_cycles :
341
+ logger .warning (
342
+ "Ignoring cycles by removing causal tests that reference any node within a cycle. "
343
+ "Your causal test suite WILL NOT BE COMPLETE!"
344
+ )
345
+ relations = generate_metamorphic_relations (
346
+ causal_dag ,
347
+ nodes_to_test = dag_nodes_to_test ,
348
+ nodes_to_ignore = set (causal_dag .cycle_nodes ()),
349
+ threads = args .threads ,
350
+ )
351
+ else :
352
+ relations = generate_metamorphic_relations (causal_dag , nodes_to_test = dag_nodes_to_test , threads = args .threads )
353
+
280
354
tests = [
281
355
relation .to_json_stub (skip = False )
282
356
for relation in relations
0 commit comments