Skip to content

Commit cfdd516

Browse files
committed
Option to ignore cycles when generating metamorphic relations
1 parent de0a676 commit cfdd516

File tree

2 files changed

+106
-38
lines changed

2 files changed

+106
-38
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph):
130130
ensures it is acyclic. A CausalDAG must be specified as a dot file.
131131
"""
132132

133-
def __init__(self, dot_path: str = None, **attr):
133+
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
134134
super().__init__(**attr)
135135
if dot_path:
136136
with open(dot_path, "r", encoding="utf-8") as file:
@@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr):
144144
self.graph = nx.DiGraph()
145145

146146
if not self.is_acyclic():
147-
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
147+
if ignore_cycles:
148+
logger.warning(
149+
"Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution."
150+
)
151+
else:
152+
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
148153

149154
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
150155
"""
@@ -188,12 +193,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
188193
if not self.is_acyclic():
189194
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
190195

196+
def cycle_nodes(self) -> list:
    """Collect the nodes that lie on any simple cycle of the underlying graph.

    The cycles reported by ``nx.simple_cycles(self.graph)`` are flattened into a single list, so a node that
    belongs to several cycles appears once per cycle it belongs to.

    :return: A (possibly empty) list of every node found on a cycle.
    """
    nodes_on_cycles = []
    for cycle in nx.simple_cycles(self.graph):
        nodes_on_cycles.extend(cycle)
    return nodes_on_cycles
201+
191202
def is_acyclic(self) -> bool:
    """Checks if the graph is acyclic.

    Only the first cycle (if any) is requested from ``nx.simple_cycles``; enumerating every simple cycle just to
    answer a yes/no question can be prohibitively expensive on densely connected graphs.

    :return: True if acyclic, False otherwise.
    """
    # simple_cycles yields cycles as non-empty node lists, so None is a safe "no cycle found" sentinel.
    first_cycle = next(iter(nx.simple_cycles(self.graph)), None)
    return first_cycle is None
197208

198209
def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
199210
"""Convert the causal DAG to a proper back-door graph.
@@ -267,7 +278,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
267278
gback.graph.remove_edge(v1, v2)
268279
return gback
269280

270-
def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
281+
def direct_effect_adjustment_sets(
282+
self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None
283+
) -> list[set[str]]:
271284
"""
272285
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
273286
and outcomes for DIRECT causal effect.
@@ -284,6 +297,9 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
284297
:rtype: list[set[str]]
285298
"""
286299

300+
if nodes_to_ignore is None:
301+
nodes_to_ignore = []
302+
287303
indirect_graph = self.get_indirect_graph(treatments, outcomes)
288304
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
289305
gam = nx.moral_graph(ancestor_graph.graph)
@@ -295,7 +311,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
295311
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
296312
if set(outcomes) in min_seps:
297313
min_seps.remove(set(outcomes))
298-
return min_seps
314+
return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps)))
299315

300316
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
301317
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments

causal_testing/specification/metamorphic_relation.py

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import networkx as nx
1414
import pandas as pd
1515
import numpy as np
16+
from multiprocessing import Pool
1617

1718
from causal_testing.specification.causal_specification import CausalDAG, Node
1819
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
@@ -214,46 +215,89 @@ def __str__(self):
214215
)
215216

216217

217-
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
218+
def generate_metamorphic_relation(
    node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> list[MetamorphicRelation]:
    """Construct the metamorphic relation implied by the Causal DAG for a given node pair.

    A ShouldCause relation is produced when the two nodes are joined by an edge, and a ShouldNotCause relation
    otherwise. The result is always a list: it holds exactly one relation, or is empty when
    ``dag.direct_effect_adjustment_sets`` returns no adjustment set for the pair (e.g. because every valid
    adjustment set contains a node to ignore).

    :param node_pair: The pair of nodes (u, v) to consider.
    :param dag: Causal DAG from which the metamorphic relation will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. Defaults to the empty set.

    :return: A list containing at most one ShouldCause or ShouldNotCause metamorphic relation.
    """
    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    (u, v) = node_pair
    edges = dag.graph.edges

    # Work out which kind of relation applies and in which direction; the adjustment-set lookup is then shared.
    if (u, v) in edges:
        # Direct edge U --> V
        relation_type, cause, effect = ShouldCause, u, v
    elif (v, u) in edges:
        # Direct edge V --> U
        relation_type, cause, effect = ShouldCause, v, u
    elif u not in nx.ancestors(dag.graph, v) and v in nx.ancestors(dag.graph, u):
        # No direct edge, but V --> ... --> U
        relation_type, cause, effect = ShouldNotCause, v, u
    else:
        # No direct edge and either U --> ... --> V, or V _||_ U (no directed walk between them, although there
        # may be a back-door path e.g. U <-- Z --> V). Only one MR is made since V _||_ U == U _||_ V.
        relation_type, cause, effect = ShouldNotCause, u, v

    adj_sets = dag.direct_effect_adjustment_sets([cause], [effect], nodes_to_ignore=nodes_to_ignore)
    if not adj_sets:
        return []
    return [relation_type(cause, effect, list(adj_sets[0]), dag)]
268+
269+
270+
def generate_metamorphic_relations(
    dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0
) -> list[MetamorphicRelation]:
    """Construct a list of metamorphic relations implied by the Causal DAG.

    This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
    relation for every (minimal) conditional independence relation implied by the structure of the DAG.
    Pairs involving a node in ``nodes_to_ignore`` are not considered at all.

    :param dag: Causal DAG from which the metamorphic relations will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. Defaults to the empty set.
    :param threads: Number of worker processes to use for a ``multiprocessing.Pool``. A falsy value (the
                    default, 0) generates the relations sequentially in the current process.

    :return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
    """
    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    node_pairs = combinations((node for node in dag.graph.nodes if node not in nodes_to_ignore), 2)

    if not threads:
        relations_per_pair = [
            generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) for node_pair in node_pairs
        ]
    else:
        # NOTE(review): the DAG and the generated relations are passed between processes here, so they
        # presumably need to be picklable - confirm for the project's types.
        with Pool(threads) as pool:
            # The result must be kept: starmap returns one list of relations per node pair.
            relations_per_pair = pool.starmap(
                generate_metamorphic_relation,
                ((node_pair, dag, nodes_to_ignore) for node_pair in node_pairs),
            )

    # Each pair contributes a (possibly empty) list, so flatten into a single list of relations.
    return [relation for relations in relations_per_pair for relation in relations]
257301

258302

259303
if __name__ == "__main__": # pragma: no cover
@@ -273,10 +317,18 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273317
help="Specify path where tests should be saved, normally a .json file.",
274318
required=True,
275319
)
320+
parser.add_argument("-i", "--ignore-cycles", action="store_true")
276321
args = parser.parse_args()
277322

278-
causal_dag = CausalDAG(args.dag_path)
279-
relations = generate_metamorphic_relations(causal_dag)
323+
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
324+
325+
if not causal_dag.is_acyclic() and args.ignore_cycles:
326+
logger.warning(
327+
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
328+
"Your causal test suite WILL NOT BE COMPLETE!"
329+
)
330+
relations = generate_metamorphic_relations(causal_dag, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=20)
331+
280332
tests = [
281333
relation.to_json_stub(skip=False)
282334
for relation in relations

0 commit comments

Comments
 (0)