Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 58 additions & 25 deletions causal_testing/specification/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,44 +214,63 @@
)


def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
"""Construct a list of metamorphic relations implied by the Causal DAG.
def generate_metamorphic_relations(dag: CausalDAG, skip_ancestors: bool) -> list[MetamorphicRelation]:
"""Construct a list of metamorphic relations based on the DAG or cyclic graph.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure if actually relevant but in the literature a causal graph with loops is called a DCG (or you could just refer to either as a "causal graph". Pedantic suggestion though


This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
If is_causal_dag is True, this list contains a ShouldCause relation for every edge, and a ShouldNotCause
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
If is_causal_dag is False, it skips checks assuming the graph is acyclic and works on general graphs with loops/cycles.

:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
:param CausalDAG dag: Graph from which the metamorphic relations will be generated.
:param bool is_causal_dag: Specifies whether the input graph is a causal DAG or a cyclic graph.
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
"""
metamorphic_relations = []

for node_pair in combinations(dag.graph.nodes, 2):
(u, v) = node_pair

# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
# If the graph is a causal DAG, perform the ancestor checks
if not skip_ancestors:
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))

# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
else:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
else:
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))

# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
# Only make one MR since V _||_ U == U _||_ V
else:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
# If the graph may contain loops/cycles, skip the ancestor checks
else:
# Create a ShouldNotCause relation for nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
adj_set = [] # No adjustment sets needed for cyclic graphs

Check warning on line 264 in causal_testing/specification/metamorphic_relation.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/specification/metamorphic_relation.py#L263-L264

Added lines #L263 - L264 were not covered by tests
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
else:
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_set = [] # No adjustment sets needed for cyclic graphs
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))

Check warning on line 270 in causal_testing/specification/metamorphic_relation.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/specification/metamorphic_relation.py#L268-L270

Added lines #L268 - L270 were not covered by tests
else:
adj_set = [] # No adjustment sets needed for cyclic graphs
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))

Check warning on line 273 in causal_testing/specification/metamorphic_relation.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/specification/metamorphic_relation.py#L272-L273

Added lines #L272 - L273 were not covered by tests

return metamorphic_relations

Expand All @@ -273,10 +292,24 @@
help="Specify path where tests should be saved, normally a .json file.",
required=True,
)

parser.add_argument(
"--skip_ancestors",
"-c",
action="store_true",
default=False,
help="Boolean flag to indicate if ancestors should be skipped. Default is False",
required=False,
)

args = parser.parse_args()

causal_dag = CausalDAG(args.dag_path)
relations = generate_metamorphic_relations(causal_dag)
relations = generate_metamorphic_relations(causal_dag, skip_ancestors=args.skip_ancestors)

if args.skip_ancestors:
logger.warning("The 'skip_ancestors' variable is set to True, proceed with caution.")

tests = [
relation.to_json_stub(skip=False)
for relation in relations
Expand Down
2 changes: 1 addition & 1 deletion tests/specification_tests/test_metamorphic_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_should_cause_metamorphic_relation_missing_relationship(self):
def test_all_metamorphic_relations_implied_by_dag(self):
dag = CausalDAG(self.dag_dot_path)
dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator
metamorphic_relations = generate_metamorphic_relations(dag)
metamorphic_relations = generate_metamorphic_relations(dag, skip_ancestors=False)
should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)]
should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)]

Expand Down
Loading