Skip to content

Commit 5d881c1

Browse files
committed
Some changes to constructive_backdoor_criterion and enumerate_minimal_adjustment_sets
1 parent 197c5d7 commit 5d881c1

File tree

2 files changed

+105
-94
lines changed

2 files changed

+105
-94
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 96 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,8 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
376376
*[set(nx.neighbors(moralised_proper_backdoor_graph, outcome)) for outcome in outcomes]
377377
) - set(outcomes)
378378

379-
neighbour_edges_to_add = list(combinations(treatment_neighbours, 2)) + list(combinations(outcome_neighbours, 2))
380-
moralised_proper_backdoor_graph.add_edges_from(neighbour_edges_to_add)
379+
moralised_proper_backdoor_graph.add_edges_from(combinations(treatment_neighbours, 2))
380+
moralised_proper_backdoor_graph.add_edges_from(combinations(outcome_neighbours, 2))
381381

382382
# 4. Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
383383
treatment_node_set = {"TREATMENT"}
@@ -596,113 +596,133 @@ def to_dot_string(self) -> str:
596596
def __str__(self):
597597
return f"Nodes: {self.nodes}\nEdges: {self.edges}"
598598

599-
class OptimisedCausalDAG(CausalDAG):
600599

600+
class OptimisedCausalDAG(CausalDAG):
601601

602602
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
603-
"""Compute minimal adjustment sets using ancestor moral graph and Takata's separator algorithm."""
603+
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
604+
and outcomes.
605+
606+
This is an implementation of the Algorithm presented in Adjustment Criteria in Causal Diagrams: An
607+
Algorithmic Perspective, Textor and Lískiewicz, 2012 and extended in Separators and adjustment sets in causal
608+
graphs: Complete criteria and an algorithmic framework, Zander et al., 2019. These works use the algorithm
609+
presented by Takata et al. in their work entitled: Space-optimal, backtracking algorithms to list the minimal
610+
vertex separators of a graph, 2013.
611+
612+
At a high-level, this algorithm proceeds as follows for a causal DAG G, set of treatments X, and set of
613+
outcomes Y):
614+
1). Transform G to a proper back-door graph G_pbd (remove the first edge from X on all proper causal paths).
615+
2). Transform G_pbd to the ancestor moral graph (G_pbd[An(X union Y)])^m.
616+
3). Apply Takata's algorithm to output all minimal X-Y separators in the graph.
617+
618+
:param treatments: A list of strings representing treatments.
619+
:param outcomes: A list of strings representing outcomes.
620+
:return: A list of strings representing the minimal adjustment set.
621+
"""
604622

605623
# Step 1: Build the proper back-door graph and its moralized ancestor graph
606-
pbd_graph = self.get_proper_backdoor_graph(treatments, outcomes)
607-
ancestor_graph = pbd_graph.get_ancestor_graph(treatments, outcomes)
608-
moral_graph = nx.moral_graph(ancestor_graph.graph)
624+
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
625+
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
626+
moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph.graph)
609627

610628
# Step 2: Add artificial TREATMENT and OUTCOME nodes
611-
moral_graph.add_edges_from([("TREATMENT", t) for t in treatments])
612-
moral_graph.add_edges_from([("OUTCOME", y) for y in outcomes])
613-
614-
# Step 3: Efficiently collect unique neighbors (excluding original nodes)
615-
treatment_neighbors = set()
616-
for t in treatments:
617-
treatment_neighbors.update(moral_graph[t])
618-
treatment_neighbors.difference_update(treatments)
619-
620-
outcome_neighbors = set()
621-
for y in outcomes:
622-
outcome_neighbors.update(moral_graph[y])
623-
outcome_neighbors.difference_update(outcomes)
624-
625-
# Step 4: Add clique edges among neighbors to preserve connectivity after node deletion
626-
moral_graph.add_edges_from(combinations(treatment_neighbors, 2))
627-
moral_graph.add_edges_from(combinations(outcome_neighbors, 2))
628-
629-
# Step 5: Find minimal separators between artificial nodes
630-
outcome_node_set = set(moral_graph["OUTCOME"]) | {"OUTCOME"}
629+
moralised_proper_backdoor_graph.add_edges_from([("TREATMENT", t) for t in treatments])
630+
moralised_proper_backdoor_graph.add_edges_from([("OUTCOME", y) for y in outcomes])
631+
632+
# Step 3: Remove treatment and outcome nodes from graph and connect neighbours
633+
treatment_neighbors = {
634+
node for t in treatments for node in moralised_proper_backdoor_graph[t] if node not in treatments
635+
}
636+
moralised_proper_backdoor_graph.add_edges_from(combinations(treatment_neighbors, 2))
637+
638+
outcome_neighbors = {
639+
node for o in outcomes for node in moralised_proper_backdoor_graph[o] if node not in outcomes
640+
}
641+
moralised_proper_backdoor_graph.add_edges_from(combinations(outcome_neighbors, 2))
642+
643+
# Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
631644
sep_candidates = list_all_min_sep_opt(
632-
moral_graph,
645+
moralised_proper_backdoor_graph,
633646
"TREATMENT",
634647
"OUTCOME",
635648
{"TREATMENT"},
636-
outcome_node_set,
649+
set(moralised_proper_backdoor_graph["OUTCOME"]) | {"OUTCOME"},
637650
)
638-
639-
# Step 6: Filter using constructive back-door criterion
640-
valid_sets = [
641-
s for s in sep_candidates
642-
if self.constructive_backdoor_criterion(pbd_graph, treatments, outcomes, s)
643-
]
644-
645-
return valid_sets
651+
return filter(
652+
lambda s: self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s),
653+
sep_candidates,
654+
)
655+
# return [
656+
# s
657+
# for s in sep_candidates
658+
# if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s)
659+
# ]
646660

647661
def constructive_backdoor_criterion(
648-
self,
649-
proper_backdoor_graph: CausalDAG,
650-
treatments: list[str],
651-
outcomes: list[str],
652-
covariates: list[str],
662+
self,
663+
proper_backdoor_graph: CausalDAG,
664+
treatments: list[str],
665+
outcomes: list[str],
666+
covariates: list[str],
653667
) -> bool:
654-
"""
655-
Optimized check for the constructive back-door criterion.
656-
"""
668+
"""A variation of Pearl's back-door criterion applied to a proper backdoor graph which enables more efficient
669+
computation of minimal adjustment sets for the effect of a set of treatments on a set of outcomes.
670+
671+
The constructive back-door criterion is satisfied for a causal DAG G, a set of treatments X, a set of outcomes
672+
Y, and a set of covariates Z, if:
673+
(1) Z is not a descendent of any variable on a proper causal path between X and Y.
674+
(2) Z d-separates X and Y in the proper back-door graph relative to X and Y.
675+
676+
Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
677+
Zander et al., 2019, Definition 4, p.16)
657678
658-
covariate_set = set(covariates)
679+
:param proper_backdoor_graph: A proper back-door graph relative to the specified treatments and outcomes.
680+
:param treatments: A list of treatment variables that appear in the proper back-door graph.
681+
:param outcomes: A list of outcome variables that appear in the proper back-door graph.
682+
:param covariates: A list of variables that appear in the proper back-door graph that we will check against
683+
the constructive back-door criterion.
684+
:return: True or False, depending on whether the set of covariates satisfies the constructive back-door
685+
criterion.
686+
"""
659687

660688
# Condition (1): Covariates must not be descendants of any node on a proper causal path
661689
proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
662-
663690
if proper_path_vars:
664691
# Collect all descendants including each proper causal path var itself
665-
all_descendants = set()
666-
for var in proper_path_vars:
667-
all_descendants.update(nx.descendants(self.graph, var))
668-
all_descendants.add(var)
692+
descendents_of_proper_casual_paths = set(proper_path_vars)
693+
descendents_of_proper_casual_paths.update(
694+
{node for var in proper_path_vars for node in nx.descendants(self.graph, var)}
695+
)
669696

670-
if covariate_set & all_descendants:
697+
if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)):
671698
# Covariates intersect with disallowed descendants — fail condition 1
672-
if logger.isEnabledFor(logging.INFO):
673-
logger.info(
674-
"Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
675-
covariates,
676-
treatments,
677-
outcomes,
678-
)
679-
return False
680-
681-
# Condition (2): Z must d-separate X and Y in the proper back-door graph
682-
if not nx.d_separated(
683-
proper_backdoor_graph.graph,
684-
set(treatments),
685-
set(outcomes),
686-
covariate_set,
687-
):
688-
if logger.isEnabledFor(logging.INFO):
689699
logger.info(
690-
"Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.",
700+
"Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
691701
covariates,
692702
treatments,
693703
outcomes,
694704
)
705+
return False
706+
707+
# Condition (2): Z must d-separate X and Y in the proper back-door graph
708+
if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)):
709+
logger.info(
710+
"Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.",
711+
covariates,
712+
treatments,
713+
outcomes,
714+
)
695715
return False
696716

697717
return True
698718

699719

700720
def list_all_min_sep_opt(
701-
graph: nx.Graph,
702-
treatment_node,
703-
outcome_node,
704-
treatment_node_set: Set,
705-
outcome_node_set: Set,
721+
graph: nx.Graph,
722+
treatment_node,
723+
outcome_node,
724+
treatment_node_set: Set,
725+
outcome_node_set: Set,
706726
) -> Generator[Set, None, None]:
707727
"""List all minimal treatment-outcome separators in an undirected graph (Takata 2013)."""
708728

@@ -755,4 +775,4 @@ def list_all_min_sep_opt(
755775
)
756776
else:
757777
# Step 8: All neighbours are in outcome set — we found a separator
758-
yield neighbour_nodes
778+
yield neighbour_nodes

tests/specification_tests/test_causal_dag.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -476,13 +476,16 @@ def test_hidden_varaible_adjustment_sets(self):
476476
def tearDown(self) -> None:
477477
shutil.rmtree(self.temp_dir_path)
478478

479+
479480
def time_it(label, func, *args, **kwargs):
480481
import time
482+
481483
start = time.time()
482484
result = func(*args, **kwargs)
483485
print(f"{label} took {time.time() - start:.6f} seconds")
484486
return result
485487

488+
486489
class TestOptimisedDAGIdentification(TestDAGIdentification):
487490
"""
488491
Test the Causal DAG identification algorithms and supporting algorithms.
@@ -495,14 +498,8 @@ def test_is_min_adjustment_for_not_min_adjustment(self):
495498

496499
opt_dag = OptimisedCausalDAG(self.dag_dot_path)
497500

498-
norm_result = time_it(
499-
"Norm",
500-
lambda: causal_dag.adjustment_set_is_minimal(xs, ys, zs)
501-
)
502-
opt_result = time_it(
503-
"Opt",
504-
lambda: opt_dag.adjustment_set_is_minimal(xs, ys, zs)
505-
)
501+
norm_result = time_it("Norm", lambda: causal_dag.adjustment_set_is_minimal(xs, ys, zs))
502+
opt_result = time_it("Opt", lambda: opt_dag.adjustment_set_is_minimal(xs, ys, zs))
506503
self.assertEqual(norm_result, opt_result)
507504

508505
def test_is_min_adjustment_for_invalid_adjustment(self):
@@ -539,7 +536,7 @@ def test_enumerate_minimal_adjustment_sets(self):
539536
causal_dag = OptimisedCausalDAG(self.dag_dot_path)
540537
xs, ys = ["X1", "X2"], ["Y"]
541538
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
542-
self.assertEqual([{"Z"}], adjustment_sets)
539+
self.assertEqual([{"Z"}], list(adjustment_sets))
543540

544541
def test_enumerate_minimal_adjustment_sets_multiple(self):
545542
"""Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
@@ -573,15 +570,9 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
573570
)
574571
xs, ys = ["X1", "X2"], ["Y"]
575572

576-
norm_adjustment_sets = time_it(
577-
"Norm",
578-
lambda: causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
579-
)
573+
norm_adjustment_sets = time_it("Norm", lambda: causal_dag.enumerate_minimal_adjustment_sets(xs, ys))
580574

581-
opt_adjustment_sets = time_it(
582-
"Opt",
583-
lambda: opt_causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
584-
)
575+
opt_adjustment_sets = time_it("Opt", lambda: opt_causal_dag.enumerate_minimal_adjustment_sets(xs, ys))
585576
set_of_opt_adjustment_sets = set(frozenset(min_separator) for min_separator in opt_adjustment_sets)
586577

587578
self.assertEqual(
@@ -634,7 +625,7 @@ def test_dag_with_non_character_nodes(self):
634625
)
635626
xs, ys = ["ba"], ["da"]
636627
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
637-
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
628+
self.assertEqual(list(adjustment_sets), [{"aa"}, {"la"}, {"va"}])
638629

639630
def tearDown(self) -> None:
640631
shutil.rmtree(self.temp_dir_path)

0 commit comments

Comments
 (0)