Skip to content

Commit 4fdf12b

Browse files
authored
Merge pull request #294 from CITCOM-project/ignore-cycles
Ignore cycles
2 parents de0a676 + 499fc54 commit 4fdf12b

File tree

8 files changed

+230
-69
lines changed

8 files changed

+230
-69
lines changed

.github/workflows/ci-tests-drafts.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
16-
python-version: ["3.9", "3.10", "3.11", "3.12"]
16+
python-version: ["3.10", "3.11", "3.12"]
1717
steps:
1818
- uses: actions/checkout@v4
1919
- name: Set up Python

.github/workflows/ci-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
matrix:
2020
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
21-
python-version: ["3.9", "3.10", "3.11", "3.12"]
21+
python-version: ["3.10", "3.11", "3.12"]
2222
steps:
2323
- uses: actions/checkout@v4
2424
- name: Set up Python

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
7070
data_paths = []
7171
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
7272

73-
def setup(self, scenario: Scenario, data=None):
73+
def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
7474
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
7575
self.scenario = scenario
7676
self._get_scenario_variables()
7777
self.scenario.setup_treatment_variables()
7878
self.causal_specification = CausalSpecification(
79-
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
79+
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles)
8080
)
8181
# Parse the JSON test plan
8282
with open(self.input_paths.json_path, encoding="utf-8") as f:

causal_testing/specification/causal_dag.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph):
130130
ensures it is acyclic. A CausalDAG must be specified as a dot file.
131131
"""
132132

133-
def __init__(self, dot_path: str = None, **attr):
133+
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
134134
super().__init__(**attr)
135135
if dot_path:
136136
with open(dot_path, "r", encoding="utf-8") as file:
@@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr):
144144
self.graph = nx.DiGraph()
145145

146146
if not self.is_acyclic():
147-
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
147+
if ignore_cycles:
148+
logger.warning(
149+
"Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution."
150+
)
151+
else:
152+
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
148153

149154
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
150155
"""
@@ -164,16 +169,17 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
164169
)
165170

166171
# (iii) Instrument and outcome do not share causes
167-
if any(
168-
(
169-
cause
170-
for cause in self.graph.nodes
171-
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
172-
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
173-
)
174-
):
175-
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
176172

173+
for cause in self.graph.nodes:
174+
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
175+
outcome_paths = (
176+
list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else []
177+
)
178+
instrument_paths = (
179+
list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else []
180+
)
181+
if len(instrument_paths) > 0 and len(outcome_paths) > 0:
182+
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
177183
return True
178184

179185
def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
@@ -188,12 +194,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
188194
if not self.is_acyclic():
189195
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
190196

197+
def cycle_nodes(self) -> list:
    """Get the nodes involved in any cycles.

    Every simple cycle of ``self.graph`` is enumerated and its nodes collected. A node that lies on several
    cycles is reported only once, at the position where it was first encountered.

    :return: A list containing all nodes involved in a cycle (empty if the graph has no cycles).
    """
    # dict.fromkeys de-duplicates while preserving first-seen order, so the result stays a stable list
    # rather than an arbitrarily ordered set.
    return list(dict.fromkeys(node for cycle in nx.simple_cycles(self.graph) for node in cycle))
202+
191203
def is_acyclic(self) -> bool:
    """Checks if the graph is acyclic.

    :return: True if acyclic, False otherwise.
    """
    # Use the dedicated DAG check rather than enumerating every simple cycle: the number of simple cycles
    # can grow exponentially with graph size, whereas this check is linear in nodes + edges. This matters
    # because acyclicity is re-checked whenever an edge is added.
    return nx.is_directed_acyclic_graph(self.graph)
197209

198210
def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
199211
"""Convert the causal DAG to a proper back-door graph.
@@ -267,7 +279,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
267279
gback.graph.remove_edge(v1, v2)
268280
return gback
269281

270-
def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
282+
def direct_effect_adjustment_sets(
283+
self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None
284+
) -> list[set[str]]:
271285
"""
272286
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
273287
and outcomes for DIRECT causal effect.
@@ -278,12 +292,17 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
278292
2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal,
279293
backtracking algorithms to list the minimal vertex separators of a graph, 2013.
280294
281-
:param list[str] treatments: List of treatment names.
282-
:param list[str] outcomes: List of outcome names.
295+
:param treatments: List of treatment names.
296+
:param outcomes: List of outcome names.
297+
:param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the
298+
adjustment set.
283299
:return: A list of possible adjustment sets.
284300
:rtype: list[set[str]]
285301
"""
286302

303+
if nodes_to_ignore is None:
304+
nodes_to_ignore = []
305+
287306
indirect_graph = self.get_indirect_graph(treatments, outcomes)
288307
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
289308
gam = nx.moral_graph(ancestor_graph.graph)
@@ -295,7 +314,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
295314
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
296315
if set(outcomes) in min_seps:
297316
min_seps.remove(set(outcomes))
298-
return min_seps
317+
return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps)))
299318

300319
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
301320
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments

causal_testing/specification/metamorphic_relation.py

Lines changed: 107 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import argparse
1111
import logging
1212
import json
13+
from multiprocessing import Pool
14+
1315
import networkx as nx
1416
import pandas as pd
1517
import numpy as np
@@ -214,46 +216,96 @@ def __str__(self):
214216
)
215217

216218

217-
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
219+
def generate_metamorphic_relation(
    node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> list[MetamorphicRelation]:
    """Construct the metamorphic relation implied by the Causal DAG for a given node pair.

    A ShouldCause relation is produced when the two nodes are joined by an edge, and a ShouldNotCause relation
    otherwise. The relation uses the first adjustment set returned by ``dag.direct_effect_adjustment_sets``.

    :param node_pair: The pair of nodes to consider.
    :param dag: Causal DAG from which the metamorphic relations will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests.

    :return: A list holding the single ShouldCause or ShouldNotCause relation for the pair, or an empty list if
             no relation can be constructed (e.g. because every valid adjustment set contains a node to ignore).
    """

    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    (u, v) = node_pair

    # Decide the kind of relation and its direction first; the adjustment-set lookup is then shared.
    if (u, v) in dag.graph.edges:
        # Edge U --> V. Checked before V --> U so that U --> V wins if both edges exist (ignored cycles).
        relation_type, treatment, outcome = ShouldCause, u, v
    elif (v, u) in dag.graph.edges:
        # Edge V --> U
        relation_type, treatment, outcome = ShouldCause, v, u
    elif u not in nx.ancestors(dag.graph, v) and v in nx.ancestors(dag.graph, u):
        # No edge, but V --> ... --> U
        relation_type, treatment, outcome = ShouldNotCause, v, u
    else:
        # No edge, and either U --> ... --> V, or V _||_ U (no directed walk between them, although there may
        # be a back-door path e.g. U <-- Z --> V). Only one MR is made since V _||_ U == U _||_ V.
        relation_type, treatment, outcome = ShouldNotCause, u, v

    adj_sets = dag.direct_effect_adjustment_sets([treatment], [outcome], nodes_to_ignore=nodes_to_ignore)
    if not adj_sets:
        return []
    return [relation_type(treatment, outcome, list(adj_sets[0]), dag)]
269+
270+
271+
def generate_metamorphic_relations(
    dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None
) -> list[MetamorphicRelation]:
    """Construct a list of metamorphic relations implied by the Causal DAG.

    This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
    relation for every (minimal) conditional independence relation implied by the structure of the DAG.

    :param dag: Causal DAG from which the metamorphic relations will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
    :param threads: Number of parallel workers to use. These are ``multiprocessing`` worker processes. A falsy
                    value (the default, 0) generates the relations sequentially in the current process.
    :param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes).

    :return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
    """

    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    if nodes_to_test is None:
        nodes_to_test = dag.graph.nodes

    # Every unordered pair of testable nodes, with ignored nodes removed up front.
    node_pairs = combinations((node for node in nodes_to_test if node not in nodes_to_ignore), 2)

    if not threads:
        relations_per_pair = [
            generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) for node_pair in node_pairs
        ]
    else:
        with Pool(threads) as pool:
            relations_per_pair = pool.starmap(
                generate_metamorphic_relation,
                ((node_pair, dag, nodes_to_ignore) for node_pair in node_pairs),
            )

    # Each pair yields a (possibly empty) list of relations; flatten them into a single list.
    return [relation for relations in relations_per_pair for relation in relations]
257309

258310

259311
if __name__ == "__main__": # pragma: no cover
@@ -273,10 +325,32 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273325
help="Specify path where tests should be saved, normally a .json file.",
274326
required=True,
275327
)
328+
parser.add_argument(
329+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
330+
)
331+
parser.add_argument("-i", "--ignore-cycles", action="store_true")
276332
args = parser.parse_args()
277333

278-
causal_dag = CausalDAG(args.dag_path)
279-
relations = generate_metamorphic_relations(causal_dag)
334+
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
335+
336+
# Nodes are tested unless their "test" attribute says otherwise. The attribute is compared with the string
# "True" (presumably because attributes read from the DOT file arrive as strings - confirm), so the default
# for nodes without the attribute must also be the string "True"; a boolean default would never match and
# would silently exclude every unannotated node.
dag_nodes_to_test = {
    node
    for node, should_test in nx.get_node_attributes(causal_dag.graph, "test", default="True").items()
    if should_test == "True"
}
339+
340+
if not causal_dag.is_acyclic() and args.ignore_cycles:
341+
logger.warning(
342+
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
343+
"Your causal test suite WILL NOT BE COMPLETE!"
344+
)
345+
relations = generate_metamorphic_relations(
346+
causal_dag,
347+
nodes_to_test=dag_nodes_to_test,
348+
nodes_to_ignore=set(causal_dag.cycle_nodes()),
349+
threads=args.threads,
350+
)
351+
else:
352+
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)
353+
280354
tests = [
281355
relation.to_json_stub(skip=False)
282356
for relation in relations

0 commit comments

Comments
 (0)