Skip to content

Commit 36a22e5

Browse files
committed
update: metamorphic_relation.py
1 parent de0a676 commit 36a22e5

File tree

1 file changed

+58
-25
lines changed

1 file changed

+58
-25
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -214,44 +214,63 @@ def __str__(self):
214214
)
215215

216216

217-
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
218-
"""Construct a list of metamorphic relations implied by the Causal DAG.
217+
def generate_metamorphic_relations(dag: CausalDAG, skip_ancestors: bool) -> list[MetamorphicRelation]:
218+
"""Construct a list of metamorphic relations based on the DAG or cyclic graph.
219219
220-
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
220+
If is_causal_dag is True, this list contains a ShouldCause relation for every edge, and a ShouldNotCause
221221
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
222+
If is_causal_dag is False, it skips checks assuming the graph is acyclic and works on general graphs with loops/cycles.
222223
223-
:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
224+
:param CausalDAG dag: Graph from which the metamorphic relations will be generated.
225+
:param bool is_causal_dag: Specifies whether the input graph is a causal DAG or a cyclic graph.
224226
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
225227
"""
226228
metamorphic_relations = []
229+
227230
for node_pair in combinations(dag.graph.nodes, 2):
228231
(u, v) = node_pair
229232

230-
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
231-
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
232-
# Case 1: U --> ... --> V
233-
if u in nx.ancestors(dag.graph, v):
233+
# If the graph is a causal DAG, perform the ancestor checks
234+
if not skip_ancestors:
235+
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
236+
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
237+
# Case 1: U --> ... --> V
238+
if u in nx.ancestors(dag.graph, v):
239+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
240+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
241+
242+
# Case 2: V --> ... --> U
243+
elif v in nx.ancestors(dag.graph, u):
244+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
245+
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
246+
247+
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
248+
else:
249+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
250+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
251+
252+
# Create a ShouldCause relation for each edge (u, v) or (v, u)
253+
elif (u, v) in dag.graph.edges:
234254
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
235-
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
236-
237-
# Case 2: V --> ... --> U
238-
elif v in nx.ancestors(dag.graph, u):
255+
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
256+
else:
239257
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
240-
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
258+
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
241259

242-
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
243-
# Only make one MR since V _||_ U == U _||_ V
244-
else:
245-
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
260+
# If the graph may contain loops/cycles, skip the ancestor checks
261+
else:
262+
# Create a ShouldNotCause relation for nodes that are not directly connected
263+
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
264+
adj_set = [] # No adjustment sets needed for cyclic graphs
246265
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
247266

248-
# Create a ShouldCause relation for each edge (u, v) or (v, u)
249-
elif (u, v) in dag.graph.edges:
250-
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
251-
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
252-
else:
253-
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
254-
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
267+
# Create a ShouldCause relation for each edge (u, v) or (v, u)
268+
elif (u, v) in dag.graph.edges:
269+
adj_set = [] # No adjustment sets needed for cyclic graphs
270+
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
271+
else:
272+
adj_set = [] # No adjustment sets needed for cyclic graphs
273+
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
255274

256275
return metamorphic_relations
257276

@@ -273,10 +292,24 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
273292
help="Specify path where tests should be saved, normally a .json file.",
274293
required=True,
275294
)
295+
296+
parser.add_argument(
297+
"--skip_ancestors",
298+
"-c",
299+
action="store_true",
300+
default=False,
301+
help="Boolean flag to indicate if ancestors should be skipped. Default is False",
302+
required=False,
303+
)
304+
276305
args = parser.parse_args()
277306

278307
causal_dag = CausalDAG(args.dag_path)
279-
relations = generate_metamorphic_relations(causal_dag)
308+
relations = generate_metamorphic_relations(causal_dag, skip_ancestors=args.skip_ancestors)
309+
310+
if args.skip_ancestors:
311+
logger.warning("The 'skip_ancestors' variable is set to True, proceed with caution.")
312+
280313
tests = [
281314
relation.to_json_stub(skip=False)
282315
for relation in relations

0 commit comments

Comments
 (0)