@@ -376,8 +376,8 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
376
376
* [set (nx .neighbors (moralised_proper_backdoor_graph , outcome )) for outcome in outcomes ]
377
377
) - set (outcomes )
378
378
379
- neighbour_edges_to_add = list (combinations (treatment_neighbours , 2 )) + list ( combinations ( outcome_neighbours , 2 ))
380
- moralised_proper_backdoor_graph .add_edges_from (neighbour_edges_to_add )
379
+ moralised_proper_backdoor_graph . add_edges_from (combinations (treatment_neighbours , 2 ))
380
+ moralised_proper_backdoor_graph .add_edges_from (combinations ( outcome_neighbours , 2 ) )
381
381
382
382
# 4. Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
383
383
treatment_node_set = {"TREATMENT" }
@@ -596,113 +596,133 @@ def to_dot_string(self) -> str:
596
596
def __str__ (self ):
597
597
return f"Nodes: { self .nodes } \n Edges: { self .edges } "
598
598
599
- class OptimisedCausalDAG (CausalDAG ):
600
599
600
+ class OptimisedCausalDAG (CausalDAG ):
601
601
602
602
def enumerate_minimal_adjustment_sets (self , treatments : list [str ], outcomes : list [str ]) -> list [set [str ]]:
603
- """Compute minimal adjustment sets using ancestor moral graph and Takata's separator algorithm."""
603
+ """Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
604
+ and outcomes.
605
+
606
+ This is an implementation of the Algorithm presented in Adjustment Criteria in Causal Diagrams: An
607
+ Algorithmic Perspective, Textor and Lískiewicz, 2012 and extended in Separators and adjustment sets in causal
608
+ graphs: Complete criteria and an algorithmic framework, Zander et al., 2019. These works use the algorithm
609
+ presented by Takata et al. in their work entitled: Space-optimal, backtracking algorithms to list the minimal
610
+ vertex separators of a graph, 2013.
611
+
612
+ At a high-level, this algorithm proceeds as follows for a causal DAG G, set of treatments X, and set of
613
+ outcomes Y):
614
+ 1). Transform G to a proper back-door graph G_pbd (remove the first edge from X on all proper causal paths).
615
+ 2). Transform G_pbd to the ancestor moral graph (G_pbd[An(X union Y)])^m.
616
+ 3). Apply Takata's algorithm to output all minimal X-Y separators in the graph.
617
+
618
+ :param treatments: A list of strings representing treatments.
619
+ :param outcomes: A list of strings representing outcomes.
620
+ :return: A list of strings representing the minimal adjustment set.
621
+ """
604
622
605
623
# Step 1: Build the proper back-door graph and its moralized ancestor graph
606
- pbd_graph = self .get_proper_backdoor_graph (treatments , outcomes )
607
- ancestor_graph = pbd_graph .get_ancestor_graph (treatments , outcomes )
608
- moral_graph = nx .moral_graph (ancestor_graph .graph )
624
+ proper_backdoor_graph = self .get_proper_backdoor_graph (treatments , outcomes )
625
+ ancestor_proper_backdoor_graph = proper_backdoor_graph .get_ancestor_graph (treatments , outcomes )
626
+ moralised_proper_backdoor_graph = nx .moral_graph (ancestor_proper_backdoor_graph .graph )
609
627
610
628
# Step 2: Add artificial TREATMENT and OUTCOME nodes
611
- moral_graph .add_edges_from ([("TREATMENT" , t ) for t in treatments ])
612
- moral_graph .add_edges_from ([("OUTCOME" , y ) for y in outcomes ])
613
-
614
- # Step 3: Efficiently collect unique neighbors (excluding original nodes)
615
- treatment_neighbors = set ()
616
- for t in treatments :
617
- treatment_neighbors .update (moral_graph [t ])
618
- treatment_neighbors .difference_update (treatments )
619
-
620
- outcome_neighbors = set ()
621
- for y in outcomes :
622
- outcome_neighbors .update (moral_graph [y ])
623
- outcome_neighbors .difference_update (outcomes )
624
-
625
- # Step 4: Add clique edges among neighbors to preserve connectivity after node deletion
626
- moral_graph .add_edges_from (combinations (treatment_neighbors , 2 ))
627
- moral_graph .add_edges_from (combinations (outcome_neighbors , 2 ))
628
-
629
- # Step 5: Find minimal separators between artificial nodes
630
- outcome_node_set = set (moral_graph ["OUTCOME" ]) | {"OUTCOME" }
629
+ moralised_proper_backdoor_graph .add_edges_from ([("TREATMENT" , t ) for t in treatments ])
630
+ moralised_proper_backdoor_graph .add_edges_from ([("OUTCOME" , y ) for y in outcomes ])
631
+
632
+ # Step 3: Remove treatment and outcome nodes from graph and connect neighbours
633
+ treatment_neighbors = {
634
+ node for t in treatments for node in moralised_proper_backdoor_graph [t ] if node not in treatments
635
+ }
636
+ moralised_proper_backdoor_graph .add_edges_from (combinations (treatment_neighbors , 2 ))
637
+
638
+ outcome_neighbors = {
639
+ node for o in outcomes for node in moralised_proper_backdoor_graph [o ] if node not in outcomes
640
+ }
641
+ moralised_proper_backdoor_graph .add_edges_from (combinations (outcome_neighbors , 2 ))
642
+
643
+ # Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
631
644
sep_candidates = list_all_min_sep_opt (
632
- moral_graph ,
645
+ moralised_proper_backdoor_graph ,
633
646
"TREATMENT" ,
634
647
"OUTCOME" ,
635
648
{"TREATMENT" },
636
- outcome_node_set ,
649
+ set ( moralised_proper_backdoor_graph [ "OUTCOME" ]) | { "OUTCOME" } ,
637
650
)
638
-
639
- # Step 6: Filter using constructive back-door criterion
640
- valid_sets = [
641
- s for s in sep_candidates
642
- if self .constructive_backdoor_criterion (pbd_graph , treatments , outcomes , s )
643
- ]
644
-
645
- return valid_sets
651
+ return filter (
652
+ lambda s : self .constructive_backdoor_criterion (proper_backdoor_graph , treatments , outcomes , s ),
653
+ sep_candidates ,
654
+ )
655
+ # return [
656
+ # s
657
+ # for s in sep_candidates
658
+ # if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s)
659
+ # ]
646
660
647
661
def constructive_backdoor_criterion (
648
- self ,
649
- proper_backdoor_graph : CausalDAG ,
650
- treatments : list [str ],
651
- outcomes : list [str ],
652
- covariates : list [str ],
662
+ self ,
663
+ proper_backdoor_graph : CausalDAG ,
664
+ treatments : list [str ],
665
+ outcomes : list [str ],
666
+ covariates : list [str ],
653
667
) -> bool :
654
- """
655
- Optimized check for the constructive back-door criterion.
656
- """
668
+ """A variation of Pearl's back-door criterion applied to a proper backdoor graph which enables more efficient
669
+ computation of minimal adjustment sets for the effect of a set of treatments on a set of outcomes.
670
+
671
+ The constructive back-door criterion is satisfied for a causal DAG G, a set of treatments X, a set of outcomes
672
+ Y, and a set of covariates Z, if:
673
+ (1) Z is not a descendent of any variable on a proper causal path between X and Y.
674
+ (2) Z d-separates X and Y in the proper back-door graph relative to X and Y.
675
+
676
+ Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
677
+ Zander et al., 2019, Definition 4, p.16)
657
678
658
- covariate_set = set (covariates )
679
+ :param proper_backdoor_graph: A proper back-door graph relative to the specified treatments and outcomes.
680
+ :param treatments: A list of treatment variables that appear in the proper back-door graph.
681
+ :param outcomes: A list of outcome variables that appear in the proper back-door graph.
682
+ :param covariates: A list of variables that appear in the proper back-door graph that we will check against
683
+ the constructive back-door criterion.
684
+ :return: True or False, depending on whether the set of covariates satisfies the constructive back-door
685
+ criterion.
686
+ """
659
687
660
688
# Condition (1): Covariates must not be descendants of any node on a proper causal path
661
689
proper_path_vars = self .proper_causal_pathway (treatments , outcomes )
662
-
663
690
if proper_path_vars :
664
691
# Collect all descendants including each proper causal path var itself
665
- all_descendants = set ()
666
- for var in proper_path_vars :
667
- all_descendants . update ( nx .descendants (self .graph , var ))
668
- all_descendants . add ( var )
692
+ descendents_of_proper_casual_paths = set (proper_path_vars )
693
+ descendents_of_proper_casual_paths . update (
694
+ { node for var in proper_path_vars for node in nx .descendants (self .graph , var )}
695
+ )
669
696
670
- if covariate_set & all_descendants :
697
+ if not set ( covariates ). issubset ( set ( self . nodes ). difference ( descendents_of_proper_casual_paths )) :
671
698
# Covariates intersect with disallowed descendants — fail condition 1
672
- if logger .isEnabledFor (logging .INFO ):
673
- logger .info (
674
- "Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s." ,
675
- covariates ,
676
- treatments ,
677
- outcomes ,
678
- )
679
- return False
680
-
681
- # Condition (2): Z must d-separate X and Y in the proper back-door graph
682
- if not nx .d_separated (
683
- proper_backdoor_graph .graph ,
684
- set (treatments ),
685
- set (outcomes ),
686
- covariate_set ,
687
- ):
688
- if logger .isEnabledFor (logging .INFO ):
689
699
logger .info (
690
- "Failed Condition 2 : Z=%s **does not ** d-separate X=%s and Y=%s in the proper back-door graph ." ,
700
+ "Failed Condition 1 : Z=%s **is ** a descendant of variables on a proper causal path between X=%s and Y=%s." ,
691
701
covariates ,
692
702
treatments ,
693
703
outcomes ,
694
704
)
705
+ return False
706
+
707
+ # Condition (2): Z must d-separate X and Y in the proper back-door graph
708
+ if not nx .d_separated (proper_backdoor_graph .graph , set (treatments ), set (outcomes ), set (covariates )):
709
+ logger .info (
710
+ "Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph." ,
711
+ covariates ,
712
+ treatments ,
713
+ outcomes ,
714
+ )
695
715
return False
696
716
697
717
return True
698
718
699
719
700
720
def list_all_min_sep_opt (
701
- graph : nx .Graph ,
702
- treatment_node ,
703
- outcome_node ,
704
- treatment_node_set : Set ,
705
- outcome_node_set : Set ,
721
+ graph : nx .Graph ,
722
+ treatment_node ,
723
+ outcome_node ,
724
+ treatment_node_set : Set ,
725
+ outcome_node_set : Set ,
706
726
) -> Generator [Set , None , None ]:
707
727
"""List all minimal treatment-outcome separators in an undirected graph (Takata 2013)."""
708
728
@@ -755,4 +775,4 @@ def list_all_min_sep_opt(
755
775
)
756
776
else :
757
777
# Step 8: All neighbours are in outcome set — we found a separator
758
- yield neighbour_nodes
778
+ yield neighbour_nodes
0 commit comments