5
5
import logging
6
6
from itertools import combinations
7
7
from random import sample
8
- from typing import Union , Generator
8
+ from typing import Union , Generator , Set
9
9
10
10
import networkx as nx
11
11
import pydot
12
12
13
- from typing import Generator , Set
14
-
15
13
from causal_testing .testing .base_test_case import BaseTestCase
16
14
17
15
from .scenario import Scenario
@@ -136,13 +134,7 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
136
134
super ().__init__ (** attr )
137
135
self .ignore_cycles = ignore_cycles
138
136
if dot_path :
139
- with open (dot_path , "r" , encoding = "utf-8" ) as file :
140
- dot_content = file .read ().replace ("\n " , "" )
141
- # Previously, we used pydot_graph_from_file() to read in the dot_path directly, however,
142
- # this method does not currently have a way of removing spurious nodes.
143
- # Workaround: Read in the file using open(), remove new lines, and then create the pydot_graph.
144
- pydot_graph = pydot .graph_from_dot_data (dot_content )
145
- self .graph = nx .DiGraph (nx .drawing .nx_pydot .from_pydot (pydot_graph [0 ]))
137
+ self .graph = nx .DiGraph (nx .nx_pydot .read_dot (dot_path ))
146
138
else :
147
139
self .graph = nx .DiGraph ()
148
140
@@ -576,10 +568,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
576
568
if scenario is not None :
577
569
minimal_adjustment_sets = self .remove_hidden_adjustment_sets (minimal_adjustment_sets , scenario )
578
570
579
- if len (minimal_adjustment_sets ) == 0 :
580
- return set ()
581
-
582
- minimal_adjustment_set = min (minimal_adjustment_sets , key = len )
571
+ minimal_adjustment_set = min (minimal_adjustment_sets , key = len , default = set ())
583
572
return set (minimal_adjustment_set )
584
573
585
574
def to_dot_string (self ) -> str :
@@ -639,7 +628,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
639
628
moralised_proper_backdoor_graph .add_edges_from (combinations (outcome_neighbors , 2 ))
640
629
641
630
# Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
642
- sep_candidates = list_all_min_sep_opt (
631
+ sep_candidates = self . list_all_min_sep_opt (
643
632
moralised_proper_backdoor_graph ,
644
633
"TREATMENT" ,
645
634
"OUTCOME" ,
@@ -689,7 +678,8 @@ def constructive_backdoor_criterion(
689
678
if not set (covariates ).issubset (set (self .nodes ).difference (descendents_of_proper_casual_paths )):
690
679
# Covariates intersect with disallowed descendants — fail condition 1
691
680
logger .info (
692
- "Failed Condition 1: Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s." ,
681
+ "Failed Condition 1: "
682
+ "Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s." ,
693
683
covariates ,
694
684
treatments ,
695
685
outcomes ,
@@ -708,76 +698,74 @@ def constructive_backdoor_criterion(
708
698
709
699
return True
710
700
701
+ def list_all_min_sep_opt (
702
+ self ,
703
+ graph : nx .Graph ,
704
+ treatment_node : str ,
705
+ outcome_node : str ,
706
+ treatment_node_set : Set ,
707
+ outcome_node_set : Set ,
708
+ ) -> Generator [Set , None , None ]:
709
+ """A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
710
+
711
+ Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
712
+ 2013, p.5, ListMinSep procedure).
713
+
714
+ :param graph: An undirected graph.
715
+ :param treatment_node: The node corresponding to the treatment variable we wish to separate from the output.
716
+ :param outcome_node: The node corresponding to the outcome variable we wish to separate from the input.
717
+ :param treatment_node_set: Set of treatment nodes.
718
+ :param outcome_node_set: Set of outcome nodes.
719
+ :return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected
720
+ graph.
721
+ """
722
+ # 1. Compute the close separator of the treatment set
723
+ close_separator_set = close_separator (graph , treatment_node , outcome_node , treatment_node_set )
711
724
712
- def list_all_min_sep_opt (
713
- graph : nx .Graph ,
714
- treatment_node ,
715
- outcome_node ,
716
- treatment_node_set : Set ,
717
- outcome_node_set : Set ,
718
- ) -> Generator [Set , None , None ]:
719
- """A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
720
-
721
- Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
722
- 2013, p.5, ListMinSep procedure).
725
+ # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
726
+ components_graph = graph .copy ()
727
+ components_graph .remove_nodes_from (close_separator_set )
723
728
724
- :param graph: An undirected graph.
725
- :param treatment_node: The node corresponding to the treatment variable we wish to separate from the output.
726
- :param outcome_node: The node corresponding to the outcome variable we wish to separate from the input.
727
- :param treatment_node_set: Set of treatment nodes.
728
- :param outcome_node_set: Set of outcome nodes.
729
- :return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected graph.
730
- """
731
- # 1. Compute the close separator of the treatment set
732
- close_separator_set = close_separator (graph , treatment_node , outcome_node , treatment_node_set )
729
+ # 3. Find the component containing the treatment node
730
+ treatment_component = set ()
731
+ for component in nx .connected_components (components_graph ):
732
+ if treatment_node in component :
733
+ treatment_component = component
734
+ break
733
735
734
- # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
735
- components_graph = graph . copy ()
736
- components_graph . remove_nodes_from ( close_separator_set )
736
+ # 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
737
+ if treatment_component . intersection ( outcome_node_set ):
738
+ return
737
739
738
- # 3. Find the component containing the treatment node
739
- treatment_component = set ()
740
- for component in nx .connected_components (components_graph ):
741
- if treatment_node in component :
742
- treatment_component = component
743
- break
740
+ # 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
741
+ treatment_node_set = treatment_component
744
742
745
- # 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
746
- if treatment_component .intersection (outcome_node_set ):
747
- raise ValueError (
748
- f"Connected component { treatment_component } containing the treatment node is not disjoint with "
749
- f"the outcome node set { outcome_node_set } "
750
- )
743
+ # 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
744
+ neighbour_nodes = {
745
+ neighbour for node in treatment_node_set for neighbour in graph [node ] if neighbour not in treatment_node_set
746
+ }
751
747
752
- # 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
753
- treatment_node_set = treatment_component
754
-
755
- # 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
756
- neighbour_nodes = {
757
- neighbour for node in treatment_node_set for neighbour in graph [node ] if neighbour not in treatment_node_set
758
- }
759
-
760
- # 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
761
- remaining = neighbour_nodes - outcome_node_set
762
- if remaining :
763
- # 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
764
- chosen = sample (sorted (remaining ), 1 )
765
- # 7.2. Add this node to the treatment node set and recurse (left branch)
766
- yield from list_all_min_sep_opt (
767
- graph ,
768
- treatment_node ,
769
- outcome_node ,
770
- treatment_node_set .union (chosen ),
771
- outcome_node_set ,
772
- )
773
- # 7.3. Add this node to the outcome node set and recurse (right branch)
774
- yield from list_all_min_sep_opt (
775
- graph ,
776
- treatment_node ,
777
- outcome_node ,
778
- treatment_node_set ,
779
- outcome_node_set .union (chosen ),
780
- )
781
- else :
782
- # Step 8: All neighbours are in outcome set — we found a separator
783
- yield neighbour_nodes
748
+ # 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
749
+ remaining = neighbour_nodes - outcome_node_set
750
+ if remaining :
751
+ # 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
752
+ chosen = sample (sorted (remaining ), 1 )
753
+ # 7.2. Add this node to the treatment node set and recurse (left branch)
754
+ yield from self .list_all_min_sep_opt (
755
+ graph ,
756
+ treatment_node ,
757
+ outcome_node ,
758
+ treatment_node_set .union (chosen ),
759
+ outcome_node_set ,
760
+ )
761
+ # 7.3. Add this node to the outcome node set and recurse (right branch)
762
+ yield from self .list_all_min_sep_opt (
763
+ graph ,
764
+ treatment_node ,
765
+ outcome_node ,
766
+ treatment_node_set ,
767
+ outcome_node_set .union (chosen ),
768
+ )
769
+ else :
770
+ # Step 8: All neighbours are in outcome set — we found a separator
771
+ yield neighbour_nodes
0 commit comments