|
5 | 5 | import logging
|
6 | 6 | from itertools import combinations
|
7 | 7 | from random import sample
|
8 |
| -from typing import Union |
| 8 | +from typing import Union, Generator |
9 | 9 |
|
10 | 10 | import networkx as nx
|
11 | 11 | import pydot
|
|
17 | 17 | from .scenario import Scenario
|
18 | 18 | from .variable import Output
|
19 | 19 |
|
20 |
| -from itertools import combinations |
21 | 20 |
|
22 | 21 | Node = Union[str, int] # Node type hint: A node is a string or an int
|
23 | 22 |
|
@@ -49,11 +48,10 @@ def list_all_min_sep(
|
49 | 48 | # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
|
50 | 49 | components_graph = graph.copy()
|
51 | 50 | components_graph.remove_nodes_from(close_separator_set)
|
52 |
| - graph_components = nx.connected_components(components_graph) |
53 | 51 |
|
54 | 52 | # 3. Find the connected component that contains the treatment node
|
55 | 53 | treatment_connected_component_node_set = set()
|
56 |
| - for component in graph_components: |
| 54 | + for component in nx.connected_components(components_graph): |
57 | 55 | if treatment_node in component:
|
58 | 56 | treatment_connected_component_node_set = component
|
59 | 57 |
|
@@ -599,7 +597,7 @@ def __str__(self):
|
599 | 597 |
|
600 | 598 | class OptimisedCausalDAG(CausalDAG):
|
601 | 599 |
|
602 |
| - def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: |
| 600 | + def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> Generator[set[str]]: |
603 | 601 | """Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
|
604 | 602 | and outcomes.
|
605 | 603 |
|
@@ -652,11 +650,6 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
|
652 | 650 | lambda s: self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s),
|
653 | 651 | sep_candidates,
|
654 | 652 | )
|
655 |
| - # return [ |
656 |
| - # s |
657 |
| - # for s in sep_candidates |
658 |
| - # if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s) |
659 |
| - # ] |
660 | 653 |
|
661 | 654 | def constructive_backdoor_criterion(
|
662 | 655 | self,
|
@@ -689,8 +682,7 @@ def constructive_backdoor_criterion(
|
689 | 682 | proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
|
690 | 683 | if proper_path_vars:
|
691 | 684 | # Collect all descendants including each proper causal path var itself
|
692 |
| - descendents_of_proper_casual_paths = set(proper_path_vars) |
693 |
| - descendents_of_proper_casual_paths.update( |
| 685 | + descendents_of_proper_casual_paths = set(proper_path_vars).union( |
694 | 686 | {node for var in proper_path_vars for node in nx.descendants(self.graph, var)}
|
695 | 687 | )
|
696 | 688 |
|
@@ -724,54 +716,67 @@ def list_all_min_sep_opt(
|
724 | 716 | treatment_node_set: Set,
|
725 | 717 | outcome_node_set: Set,
|
726 | 718 | ) -> Generator[Set, None, None]:
|
727 |
| - """List all minimal treatment-outcome separators in an undirected graph (Takata 2013).""" |
| 719 | + """A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph. |
728 | 720 |
|
729 |
| - # Step 1: Compute the close separator |
| 721 | + Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata, |
| 722 | + 2013, p.5, ListMinSep procedure). |
| 723 | +
|
| 724 | + :param graph: An undirected graph. |
| 725 | + :param treatment_node: The node corresponding to the treatment variable we wish to separate from the output. |
| 726 | + :param outcome_node: The node corresponding to the outcome variable we wish to separate from the input. |
| 727 | + :param treatment_node_set: Set of treatment nodes. |
| 728 | + :param outcome_node_set: Set of outcome nodes. |
| 729 | + :return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected graph. |
| 730 | + """ |
| 731 | + # 1. Compute the close separator of the treatment set |
730 | 732 | close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set)
|
731 | 733 |
|
732 |
| - # Step 2: Remove separator to identify connected components |
733 |
| - subgraph = graph.copy() |
734 |
| - subgraph.remove_nodes_from(close_separator_set) |
| 734 | + # 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs) |
| 735 | + components_graph = graph.copy() |
| 736 | + components_graph.remove_nodes_from(close_separator_set) |
735 | 737 |
|
736 |
| - # Step 3: Find the component containing the treatment node |
737 |
| - treatment_component = None |
738 |
| - for component in nx.connected_components(subgraph): |
| 738 | + # 3. Find the component containing the treatment node |
| 739 | + treatment_component = set() |
| 740 | + for component in nx.connected_components(components_graph): |
739 | 741 | if treatment_node in component:
|
740 | 742 | treatment_component = component
|
741 | 743 | break
|
742 | 744 |
|
743 |
| - # Step 4: Stop early if no component found or intersects outcome set |
744 |
| - if treatment_component is None or treatment_component & outcome_node_set: |
745 |
| - return |
| 745 | + # 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set |
| 746 | + if treatment_component.intersection(outcome_node_set): |
| 747 | + raise ValueError( |
| 748 | + f"Connected component {treatment_component} containing the treatment node is not disjoint with " |
| 749 | + f"the outcome node set {outcome_node_set}" |
| 750 | + ) |
746 | 751 |
|
747 |
| - # Step 5: Update treatment node set to the connected component |
| 752 | + # 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node |
748 | 753 | treatment_node_set = treatment_component
|
749 | 754 |
|
750 |
| - # Step 6: Get neighbours of the treatment set |
751 |
| - neighbour_nodes = set() |
752 |
| - for node in treatment_node_set: |
753 |
| - neighbour_nodes.update(graph[node]) |
754 |
| - neighbour_nodes.difference_update(treatment_node_set) |
| 755 | + # 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves) |
| 756 | + neighbour_nodes = { |
| 757 | + neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set |
| 758 | + } |
755 | 759 |
|
756 |
| - # Step 7: If neighbours exist outside the outcome set, recurse |
| 760 | + # 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set |
757 | 761 | remaining = neighbour_nodes - outcome_node_set
|
758 | 762 | if remaining:
|
759 |
| - chosen = sample(sorted(remaining), 1)[0] # Choose one deterministically (sorted) but randomly |
760 |
| - # Left branch: add to treatment set |
| 763 | + # 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set |
| 764 | + chosen = sample(sorted(remaining), 1) |
| 765 | + # 7.2. Add this node to the treatment node set and recurse (left branch) |
761 | 766 | yield from list_all_min_sep_opt(
|
762 | 767 | graph,
|
763 | 768 | treatment_node,
|
764 | 769 | outcome_node,
|
765 |
| - treatment_node_set | {chosen}, |
| 770 | + treatment_node_set.union(chosen), |
766 | 771 | outcome_node_set,
|
767 | 772 | )
|
768 |
| - # Right branch: add to outcome set |
| 773 | + # 7.3. Add this node to the outcome node set and recurse (right branch) |
769 | 774 | yield from list_all_min_sep_opt(
|
770 | 775 | graph,
|
771 | 776 | treatment_node,
|
772 | 777 | outcome_node,
|
773 | 778 | treatment_node_set,
|
774 |
| - outcome_node_set | {chosen}, |
| 779 | + outcome_node_set.union(chosen), |
775 | 780 | )
|
776 | 781 | else:
|
777 | 782 | # Step 8: All neighbours are in outcome set — we found a separator
|
|
0 commit comments