diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index 9d8c8afb..3fb5de48 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -214,44 +214,65 @@ def __str__(self): ) -def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: - """Construct a list of metamorphic relations implied by the Causal DAG. +def generate_metamorphic_relations(dag: CausalDAG, skip_ancestors: bool = False) -> list[MetamorphicRelation]: + """Construct a list of metamorphic relations based on the DAG or cyclic graph. - This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause - relation for every (minimal) conditional independence relation implied by the structure of the DAG. + If skip_ancestors is False, this list contains a ShouldCause relation for every edge, and a + ShouldNotCause relation for every (minimal) conditional independence relation implied by + the structure of the DAG. If skip_ancestors is True, it skips checks assuming the graph + is acyclic and works on general graphs with loops/cycles. - :param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated. + :param CausalDAG dag: Graph from which the metamorphic relations will be generated. + :param bool skip_ancestors: Boolean parameter to determine if the ancestor checks + should be skipped. Default is False. :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ metamorphic_relations = [] + for node_pair in combinations(dag.graph.nodes, 2): (u, v) = node_pair - # Create a ShouldNotCause relation for each pair of nodes that are not directly connected - if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): - # Case 1: U --> ... --> V - if u in nx.ancestors(dag.graph, v): + # If the graph is a causal DAG, perform the ancestor checks + if not skip_ancestors: + # Create a ShouldNotCause relation for each pair of nodes that are not directly connected + if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): + # Case 1: U --> ... --> V + if u in nx.ancestors(dag.graph, v): + adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) + metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) + + # Case 2: V --> ... --> U + elif v in nx.ancestors(dag.graph, u): + adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) + metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag)) + + # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). + else: + adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) + metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) + + # Create a ShouldCause relation for each edge (u, v) or (v, u) + elif (u, v) in dag.graph.edges: adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - - # Case 2: V --> ... --> U - elif v in nx.ancestors(dag.graph, u): + metamorphic_relations.append(ShouldCause(u, v, adj_set, dag)) + else: adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag)) + metamorphic_relations.append(ShouldCause(v, u, adj_set, dag)) - # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). - # Only make one MR since V _||_ U == U _||_ V - else: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) + # If the graph may contain loops/cycles, skip the ancestor checks + else: + # Create a ShouldNotCause relation for nodes that are not directly connected + if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): + adj_set = [] # No adjustment sets needed for cyclic graphs metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - # Create a ShouldCause relation for each edge (u, v) or (v, u) - elif (u, v) in dag.graph.edges: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldCause(u, v, adj_set, dag)) - else: - adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldCause(v, u, adj_set, dag)) + # Create a ShouldCause relation for each edge (u, v) or (v, u) + elif (u, v) in dag.graph.edges: + adj_set = [] # No adjustment sets needed for cyclic graphs + metamorphic_relations.append(ShouldCause(u, v, adj_set, dag)) + else: + adj_set = [] # No adjustment sets needed for cyclic graphs + metamorphic_relations.append(ShouldCause(v, u, adj_set, dag)) return metamorphic_relations @@ -273,10 +294,24 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: help="Specify path where tests should be saved, normally a .json file.", required=True, ) + + parser.add_argument( + "--skip_ancestors", + "-c", + action="store_true", + default=False, + help="Boolean flag to indicate if ancestors should be skipped. Default is False", + required=False, + ) + args = parser.parse_args() causal_dag = CausalDAG(args.dag_path) - relations = generate_metamorphic_relations(causal_dag) + relations = generate_metamorphic_relations(causal_dag, skip_ancestors=args.skip_ancestors) + + if args.skip_ancestors: + logger.warning("The 'skip_ancestors' variable is set to True, proceed with caution.") + tests = [ relation.to_json_stub(skip=False) for relation in relations diff --git a/tests/specification_tests/test_metamorphic_relations.py b/tests/specification_tests/test_metamorphic_relations.py index dc35e071..a34632f6 100644 --- a/tests/specification_tests/test_metamorphic_relations.py +++ b/tests/specification_tests/test_metamorphic_relations.py @@ -205,7 +205,7 @@ def test_should_cause_metamorphic_relation_missing_relationship(self): def test_all_metamorphic_relations_implied_by_dag(self): dag = CausalDAG(self.dag_dot_path) dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator - metamorphic_relations = generate_metamorphic_relations(dag) + metamorphic_relations = generate_metamorphic_relations(dag, skip_ancestors=False) should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)]