10
10
import argparse
11
11
import logging
12
12
import json
13
+ from multiprocessing import Pool
14
+
13
15
import networkx as nx
14
16
import pandas as pd
15
17
import numpy as np
16
- from multiprocessing import Pool
17
18
18
19
from causal_testing .specification .causal_specification import CausalDAG , Node
19
20
from causal_testing .data_collection .data_collector import ExperimentalDataCollector
@@ -268,7 +269,7 @@ def generate_metamorphic_relation(
268
269
269
270
270
271
def generate_metamorphic_relations (
271
- dag : CausalDAG , nodes_to_ignore : set = {}, threads : int = 0
272
+ dag : CausalDAG , nodes_to_ignore : set = {}, threads : int = 0 , nodes_to_test : set = None
272
273
) -> list [MetamorphicRelation ]:
273
274
"""Construct a list of metamorphic relations implied by the Causal DAG.
274
275
@@ -282,18 +283,21 @@ def generate_metamorphic_relations(
282
283
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
283
284
"""
284
285
286
+ if nodes_to_test is None :
287
+ nodes_to_test = dag .graph .nodes
288
+
285
289
if not threads :
286
290
metamorphic_relations = [
287
291
generate_metamorphic_relation (node_pair , dag , nodes_to_ignore )
288
- for node_pair in combinations (filter (lambda node : node not in nodes_to_ignore , dag . graph . nodes ), 2 )
292
+ for node_pair in combinations (filter (lambda node : node not in nodes_to_ignore , nodes_to_test ), 2 )
289
293
]
290
294
else :
291
295
with Pool (threads ) as pool :
292
- pool .starmap (
296
+ metamorphic_relations = pool .starmap (
293
297
generate_metamorphic_relation ,
294
298
map (
295
299
lambda node_pair : (node_pair , dag , nodes_to_ignore ),
296
- combinations (filter (lambda node : node not in nodes_to_ignore , dag . graph . nodes ), 2 ),
300
+ combinations (filter (lambda node : node not in nodes_to_ignore , nodes_to_test ), 2 ),
297
301
),
298
302
)
299
303
@@ -317,17 +321,28 @@ def generate_metamorphic_relations(
317
321
help = "Specify path where tests should be saved, normally a .json file." ,
318
322
required = True ,
319
323
)
324
+ parser .add_argument (
325
+ "--threads" , "-t" , type = int , help = "The number of parallel threads to use." , required = False , default = 0
326
+ )
320
327
parser .add_argument ("-i" , "--ignore-cycles" , action = "store_true" )
321
328
args = parser .parse_args ()
322
329
323
330
causal_dag = CausalDAG (args .dag_path , ignore_cycles = args .ignore_cycles )
324
331
332
+ nodes_to_test = set (
333
+ k for k , v in nx .get_node_attributes (causal_dag .graph , "test" , default = True ).items () if v == "True"
334
+ )
335
+
325
336
if not causal_dag .is_acyclic () and args .ignore_cycles :
326
337
logger .warning (
327
338
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
328
339
"Your causal test suite WILL NOT BE COMPLETE!"
329
340
)
330
- relations = generate_metamorphic_relations (causal_dag , nodes_to_ignore = set (causal_dag .cycle_nodes ()), threads = 20 )
341
+ relations = generate_metamorphic_relations (
342
+ causal_dag , nodes_to_test = nodes_to_test , nodes_to_ignore = set (causal_dag .cycle_nodes ()), threads = args .threads
343
+ )
344
+ else :
345
+ relations = generate_metamorphic_relations (causal_dag , nodes_to_test = nodes_to_test , threads = args .threads )
331
346
332
347
tests = [
333
348
relation .to_json_stub (skip = False )
0 commit comments