Skip to content

Commit 197c5d7

Browse files
Experimental optimisations, proposed by ChatGPT.
1 parent b162d8f commit 197c5d7

File tree

2 files changed

+329
-1
lines changed

2 files changed

+329
-1
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
import networkx as nx
1111
import pydot
1212

13+
from typing import Generator, Set
14+
1315
from causal_testing.testing.base_test_case import BaseTestCase
1416

1517
from .scenario import Scenario
1618
from .variable import Output
1719

20+
from itertools import combinations
21+
1822
Node = Union[str, int] # Node type hint: A node is a string or an int
1923

2024
logger = logging.getLogger(__name__)
@@ -591,3 +595,164 @@ def to_dot_string(self) -> str:
591595

592596
def __str__(self):
593597
return f"Nodes: {self.nodes}\nEdges: {self.edges}"
598+
599+
class OptimisedCausalDAG(CausalDAG):
    """A CausalDAG whose adjustment-set enumeration and constructive back-door
    check are re-implemented with plain set operations."""

    def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
        """Return the minimal adjustment sets for ``treatments`` and ``outcomes``.

        The moralised ancestor graph of the proper back-door graph is extended
        with two artificial nodes, ``"TREATMENT"`` and ``"OUTCOME"``. Minimal
        separators between them are listed with ``list_all_min_sep_opt`` and
        only those passing ``constructive_backdoor_criterion`` are kept.

        :param treatments: Names of the treatment variables.
        :param outcomes: Names of the outcome variables.
        :return: The separators that satisfy the constructive back-door
            criterion, in the order they were enumerated.
        """
        # Proper back-door graph -> ancestor graph -> undirected moral graph.
        backdoor_dag = self.get_proper_backdoor_graph(treatments, outcomes)
        ancestors = backdoor_dag.get_ancestor_graph(treatments, outcomes)
        moralised = nx.moral_graph(ancestors.graph)

        # Attach one artificial node to every treatment and one to every outcome.
        moralised.add_edges_from(("TREATMENT", treatment) for treatment in treatments)
        moralised.add_edges_from(("OUTCOME", outcome) for outcome in outcomes)

        # Neighbours of the treatments / outcomes, not counting those nodes themselves.
        around_treatments = set().union(*(moralised[treatment] for treatment in treatments)) - set(treatments)
        around_outcomes = set().union(*(moralised[outcome] for outcome in outcomes)) - set(outcomes)

        # Turn each neighbourhood into a clique. The treatment and outcome
        # nodes themselves stay in the graph.
        for neighbourhood in (around_treatments, around_outcomes):
            moralised.add_edges_from(combinations(neighbourhood, 2))

        # Enumerate minimal separators between the two artificial nodes.
        outcome_side = {"OUTCOME", *moralised["OUTCOME"]}
        candidates = list_all_min_sep_opt(moralised, "TREATMENT", "OUTCOME", {"TREATMENT"}, outcome_side)

        # Keep only the candidates that are valid adjustment sets.
        accepted = []
        for candidate in candidates:
            if self.constructive_backdoor_criterion(backdoor_dag, treatments, outcomes, candidate):
                accepted.append(candidate)
        return accepted

    def constructive_backdoor_criterion(
        self,
        proper_backdoor_graph: CausalDAG,
        treatments: list[str],
        outcomes: list[str],
        covariates: list[str],
    ) -> bool:
        """Check whether ``covariates`` satisfy the constructive back-door criterion.

        :param proper_backdoor_graph: Proper back-door graph for the given
            treatments and outcomes.
        :param treatments: Names of the treatment variables.
        :param outcomes: Names of the outcome variables.
        :param covariates: Candidate adjustment set (converted to a ``set``
            internally).
        :return: ``True`` when both conditions hold; ``False`` otherwise, after
            logging at INFO level which condition failed.
        """
        adjustment = set(covariates)

        # Condition (1): no covariate may be a variable on a proper causal
        # path, nor a descendant of one.
        forbidden = set()
        for path_var in self.proper_causal_pathway(treatments, outcomes) or ():
            forbidden.add(path_var)
            forbidden |= nx.descendants(self.graph, path_var)

        if not adjustment.isdisjoint(forbidden):
            logger.info(
                "Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
                covariates,
                treatments,
                outcomes,
            )
            return False

        # Condition (2): the covariates must d-separate treatments from
        # outcomes in the proper back-door graph.
        separated = nx.d_separated(
            proper_backdoor_graph.graph,
            set(treatments),
            set(outcomes),
            adjustment,
        )
        if separated:
            return True

        logger.info(
            "Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.",
            covariates,
            treatments,
            outcomes,
        )
        return False
698+
699+
700+
def list_all_min_sep_opt(
    graph: nx.Graph,
    treatment_node,
    outcome_node,
    treatment_node_set: Set,
    outcome_node_set: Set,
) -> Generator[Set, None, None]:
    """Lazily list all minimal treatment-outcome separators in an undirected graph (Takata 2013).

    The search is a binary recursion: a node adjacent to the current treatment
    side is either absorbed into the treatment side (left branch) or pushed to
    the outcome side (right branch). A separator is yielded once every
    neighbour of the treatment side already belongs to the outcome side.

    The branching node is the smallest undecided neighbour, so the order in
    which separators are yielded is reproducible between runs. Node labels
    must therefore be mutually orderable.

    :param graph: Undirected graph to search.
    :param treatment_node: Node the separators must cut off from the outcome.
    :param outcome_node: Node the separators must cut off from the treatment.
    :param treatment_node_set: Nodes currently assigned to the treatment side.
    :param outcome_node_set: Nodes currently assigned to the outcome side.
    :return: Generator yielding each separator as a set of nodes.
    """
    # Step 1: Compute the close separator of the current treatment side.
    separator = close_separator(graph, treatment_node, outcome_node, treatment_node_set)

    # Step 2: Remove the separator and locate the component holding the treatment node.
    remainder = graph.copy()
    remainder.remove_nodes_from(separator)
    treatment_component = next(
        (component for component in nx.connected_components(remainder) if treatment_node in component),
        None,
    )

    # Step 3: Dead branch if there is no such component or it touches the outcome side.
    if treatment_component is None or treatment_component & outcome_node_set:
        return

    # Step 4: Neighbours of the treatment component that lie outside it.
    boundary = set()
    for node in treatment_component:
        boundary.update(graph[node])
    boundary -= treatment_component

    # Step 5: Every neighbour already on the outcome side means the boundary is a separator.
    undecided = boundary - outcome_node_set
    if not undecided:
        yield boundary
        return

    # Step 6: Branch on one undecided neighbour. Any choice is valid for the
    # enumeration; taking the minimum keeps the output order deterministic.
    pivot = min(undecided)

    # Left branch: pivot joins the treatment side.
    yield from list_all_min_sep_opt(
        graph,
        treatment_node,
        outcome_node,
        treatment_component | {pivot},
        outcome_node_set,
    )
    # Right branch: pivot joins the outcome side.
    yield from list_all_min_sep_opt(
        graph,
        treatment_node,
        outcome_node,
        treatment_component,
        outcome_node_set | {pivot},
    )

tests/specification_tests/test_causal_dag.py

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import shutil, tempfile
44
import networkx as nx
5-
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep
5+
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep, OptimisedCausalDAG
66
from causal_testing.specification.scenario import Scenario
77
from causal_testing.specification.variable import Input, Output
88
from causal_testing.testing.base_test_case import BaseTestCase
@@ -475,3 +475,166 @@ def test_hidden_varaible_adjustment_sets(self):
475475

476476
def tearDown(self) -> None:
477477
shutil.rmtree(self.temp_dir_path)
478+
479+
def time_it(label, func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, print how long it took, and return its result.

    :param label: Text printed in front of the timing.
    :param func: Callable to time.
    :param args: Positional arguments forwarded to ``func``.
    :param kwargs: Keyword arguments forwarded to ``func``.
    :return: Whatever ``func`` returned.
    """
    import time  # Kept local so the helper does not depend on the module's import block.

    # perf_counter is monotonic and high resolution, unlike wall-clock time.time().
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"{label} took {elapsed:.6f} seconds")
    return result
485+
486+
class TestOptimisedDAGIdentification(TestDAGIdentification):
    """
    Run the DAG identification tests against OptimisedCausalDAG and, where
    timed, compare its results with the baseline CausalDAG.
    """

    def test_is_min_adjustment_for_not_min_adjustment(self):
        """Baseline and optimised DAGs must agree on whether {"Z", "V"} is a minimal adjustment set."""
        causal_dag = CausalDAG(self.dag_dot_path)
        xs, ys, zs = ["X1", "X2"], ["Y"], {"Z", "V"}

        opt_dag = OptimisedCausalDAG(self.dag_dot_path)

        norm_result = time_it(
            "Norm",
            lambda: causal_dag.adjustment_set_is_minimal(xs, ys, zs),
        )
        opt_result = time_it(
            "Opt",
            lambda: opt_dag.adjustment_set_is_minimal(xs, ys, zs),
        )
        self.assertEqual(norm_result, opt_result)

    def test_is_min_adjustment_for_invalid_adjustment(self):
        """An empty adjustment set must make adjustment_set_is_minimal raise ValueError."""
        causal_dag = OptimisedCausalDAG(self.dag_dot_path)
        xs, ys, zs = ["X1", "X2"], ["Y"], set()
        self.assertRaises(ValueError, causal_dag.adjustment_set_is_minimal, xs, ys, zs)

    def test_get_ancestor_graph_of_causal_dag(self):
        """Test whether get_ancestor_graph converts a CausalDAG to the correct ancestor graph."""
        causal_dag = OptimisedCausalDAG(self.dag_dot_path)
        xs, ys = ["X1", "X2"], ["Y"]
        ancestor_graph = causal_dag.get_ancestor_graph(xs, ys)
        self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
        self.assertEqual(
            list(ancestor_graph.edges),
            [("X1", "X2"), ("X2", "D1"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
        )

    def test_get_ancestor_graph_of_proper_backdoor_graph(self):
        """Test whether get_ancestor_graph gives the correct ancestor graph of a proper back-door graph."""
        causal_dag = OptimisedCausalDAG(self.dag_dot_path)
        xs, ys = ["X1", "X2"], ["Y"]
        proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
        ancestor_graph = proper_backdoor_graph.get_ancestor_graph(xs, ys)
        self.assertEqual(list(ancestor_graph.nodes), ["X1", "X2", "D1", "Y", "Z"])
        self.assertEqual(
            list(ancestor_graph.edges),
            [("X1", "X2"), ("D1", "Y"), ("Z", "X2"), ("Z", "Y")],
        )

    def test_enumerate_minimal_adjustment_sets(self):
        """Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
        causal_dag = OptimisedCausalDAG(self.dag_dot_path)
        xs, ys = ["X1", "X2"], ["Y"]
        adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
        self.assertEqual([{"Z"}], adjustment_sets)

    def test_enumerate_minimal_adjustment_sets_multiple(self):
        """Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
        edges = [
            ("X1", "X2"),
            ("X2", "V"),
            ("Z1", "X2"),
            ("Z1", "Z2"),
            ("Z2", "Z3"),
            ("Z3", "Y"),
            ("D1", "Y"),
            ("D1", "D2"),
            ("Y", "D3"),
        ]
        causal_dag = CausalDAG()
        causal_dag.graph.add_edges_from(edges)
        # The optimised implementation is the subject under test, so it must
        # be an OptimisedCausalDAG rather than a second baseline CausalDAG.
        opt_causal_dag = OptimisedCausalDAG()
        opt_causal_dag.graph.add_edges_from(edges)
        xs, ys = ["X1", "X2"], ["Y"]

        norm_adjustment_sets = time_it(
            "Norm",
            lambda: causal_dag.enumerate_minimal_adjustment_sets(xs, ys),
        )

        opt_adjustment_sets = time_it(
            "Opt",
            lambda: opt_causal_dag.enumerate_minimal_adjustment_sets(xs, ys),
        )
        set_of_norm_adjustment_sets = set(frozenset(min_separator) for min_separator in norm_adjustment_sets)
        set_of_opt_adjustment_sets = set(frozenset(min_separator) for min_separator in opt_adjustment_sets)

        self.assertEqual(
            {frozenset({"Z1"}), frozenset({"Z2"}), frozenset({"Z3"})},
            set_of_opt_adjustment_sets,
        )
        # Baseline and optimised implementations must find the same sets.
        self.assertEqual(set_of_norm_adjustment_sets, set_of_opt_adjustment_sets)

    def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
        """Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
        causal_dag = OptimisedCausalDAG()
        causal_dag.graph.add_edges_from(
            [
                ("X1", "X2"),
                ("X2", "V"),
                ("Z1", "X2"),
                ("Z1", "Z2"),
                ("Z2", "Z3"),
                ("Z3", "Y"),
                ("D1", "Y"),
                ("D1", "D2"),
                ("Y", "D3"),
                ("Z4", "X1"),
                ("Z4", "Y"),
                ("X2", "D1"),
            ]
        )
        xs, ys = ["X1", "X2"], ["Y"]
        adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
        set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
        self.assertEqual(
            {frozenset({"Z1", "Z4"}), frozenset({"Z2", "Z4"}), frozenset({"Z3", "Z4"})},
            set_of_adjustment_sets,
        )

    def test_dag_with_non_character_nodes(self):
        """Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
        causal_dag = OptimisedCausalDAG()
        causal_dag.graph.add_edges_from(
            [
                ("va", "ba"),
                ("ba", "ia"),
                ("ba", "da"),
                ("ba", "ra"),
                ("la", "va"),
                ("la", "aa"),
                ("aa", "ia"),
                ("aa", "da"),
                ("aa", "ra"),
            ]
        )
        xs, ys = ["ba"], ["da"]
        adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
        self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])

    def tearDown(self) -> None:
        # Remove the temporary directory referenced by self.temp_dir_path.
        shutil.rmtree(self.temp_dir_path)

0 commit comments

Comments
 (0)