Skip to content

Commit da3fb4d

Browse files
committed
Capability to specify nodes to test in causal dag
1 parent cfdd516 commit da3fb4d

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
import argparse
1111
import logging
1212
import json
13+
from multiprocessing import Pool
14+
1315
import networkx as nx
1416
import pandas as pd
1517
import numpy as np
16-
from multiprocessing import Pool
1718

1819
from causal_testing.specification.causal_specification import CausalDAG, Node
1920
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
@@ -268,7 +269,7 @@ def generate_metamorphic_relation(
268269

269270

270271
def generate_metamorphic_relations(
271-
dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0
272+
dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0, nodes_to_test: set = None
272273
) -> list[MetamorphicRelation]:
273274
"""Construct a list of metamorphic relations implied by the Causal DAG.
274275
@@ -282,18 +283,21 @@ def generate_metamorphic_relations(
282283
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
283284
"""
284285

286+
if nodes_to_test is None:
287+
nodes_to_test = dag.graph.nodes
288+
285289
if not threads:
286290
metamorphic_relations = [
287291
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
288-
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2)
292+
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
289293
]
290294
else:
291295
with Pool(threads) as pool:
292-
pool.starmap(
296+
metamorphic_relations = pool.starmap(
293297
generate_metamorphic_relation,
294298
map(
295299
lambda node_pair: (node_pair, dag, nodes_to_ignore),
296-
combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2),
300+
combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2),
297301
),
298302
)
299303

@@ -317,17 +321,28 @@ def generate_metamorphic_relations(
317321
help="Specify path where tests should be saved, normally a .json file.",
318322
required=True,
319323
)
324+
parser.add_argument(
325+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
326+
)
320327
parser.add_argument("-i", "--ignore-cycles", action="store_true")
321328
args = parser.parse_args()
322329

323330
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
324331

332+
nodes_to_test = set(
333+
k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True"
334+
)
335+
325336
if not causal_dag.is_acyclic() and args.ignore_cycles:
326337
logger.warning(
327338
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
328339
"Your causal test suite WILL NOT BE COMPLETE!"
329340
)
330-
relations = generate_metamorphic_relations(causal_dag, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=20)
341+
relations = generate_metamorphic_relations(
342+
causal_dag, nodes_to_test=nodes_to_test, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=args.threads
343+
)
344+
else:
345+
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=nodes_to_test, threads=args.threads)
331346

332347
tests = [
333348
relation.to_json_stub(skip=False)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"fitter~=1.7",
2020
"lifelines~=0.29.0",
2121
"lhsmdu~=1.1",
22-
"networkx~=2.6",
22+
"networkx~=3.4",
2323
"numpy~=1.26",
2424
"pandas>=2.1",
2525
"scikit_learn~=1.4",

0 commit comments

Comments
 (0)