Skip to content

Commit 3dc9ef3

Browse files
committed
MF reviewed all optimisations, tests pass
1 parent 5d881c1 commit 3dc9ef3

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from itertools import combinations
77
from random import sample
8-
from typing import Union
8+
from typing import Union, Generator
99

1010
import networkx as nx
1111
import pydot
@@ -17,7 +17,6 @@
1717
from .scenario import Scenario
1818
from .variable import Output
1919

20-
from itertools import combinations
2120

2221
Node = Union[str, int] # Node type hint: A node is a string or an int
2322

@@ -49,11 +48,10 @@ def list_all_min_sep(
4948
# 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
5049
components_graph = graph.copy()
5150
components_graph.remove_nodes_from(close_separator_set)
52-
graph_components = nx.connected_components(components_graph)
5351

5452
# 3. Find the connected component that contains the treatment node
5553
treatment_connected_component_node_set = set()
56-
for component in graph_components:
54+
for component in nx.connected_components(components_graph):
5755
if treatment_node in component:
5856
treatment_connected_component_node_set = component
5957

@@ -599,7 +597,7 @@ def __str__(self):
599597

600598
class OptimisedCausalDAG(CausalDAG):
601599

602-
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
600+
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> Generator[set[str]]:
603601
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
604602
and outcomes.
605603
@@ -652,11 +650,6 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
652650
lambda s: self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s),
653651
sep_candidates,
654652
)
655-
# return [
656-
# s
657-
# for s in sep_candidates
658-
# if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s)
659-
# ]
660653

661654
def constructive_backdoor_criterion(
662655
self,
@@ -689,8 +682,7 @@ def constructive_backdoor_criterion(
689682
proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
690683
if proper_path_vars:
691684
# Collect all descendants including each proper causal path var itself
692-
descendents_of_proper_casual_paths = set(proper_path_vars)
693-
descendents_of_proper_casual_paths.update(
685+
descendents_of_proper_casual_paths = set(proper_path_vars).union(
694686
{node for var in proper_path_vars for node in nx.descendants(self.graph, var)}
695687
)
696688

@@ -724,54 +716,67 @@ def list_all_min_sep_opt(
724716
treatment_node_set: Set,
725717
outcome_node_set: Set,
726718
) -> Generator[Set, None, None]:
727-
"""List all minimal treatment-outcome separators in an undirected graph (Takata 2013)."""
719+
"""A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
728720
729-
# Step 1: Compute the close separator
721+
Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
722+
2013, p.5, ListMinSep procedure).
723+
724+
:param graph: An undirected graph.
725+
:param treatment_node: The node corresponding to the treatment variable we wish to separate from the output.
726+
:param outcome_node: The node corresponding to the outcome variable we wish to separate from the input.
727+
:param treatment_node_set: Set of treatment nodes.
728+
:param outcome_node_set: Set of outcome nodes.
729+
:return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected graph.
730+
"""
731+
# 1. Compute the close separator of the treatment set
730732
close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set)
731733

732-
# Step 2: Remove separator to identify connected components
733-
subgraph = graph.copy()
734-
subgraph.remove_nodes_from(close_separator_set)
734+
# 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
735+
components_graph = graph.copy()
736+
components_graph.remove_nodes_from(close_separator_set)
735737

736-
# Step 3: Find the component containing the treatment node
737-
treatment_component = None
738-
for component in nx.connected_components(subgraph):
738+
# 3. Find the component containing the treatment node
739+
treatment_component = set()
740+
for component in nx.connected_components(components_graph):
739741
if treatment_node in component:
740742
treatment_component = component
741743
break
742744

743-
# Step 4: Stop early if no component found or intersects outcome set
744-
if treatment_component is None or treatment_component & outcome_node_set:
745-
return
745+
# 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
746+
if treatment_component.intersection(outcome_node_set):
747+
raise ValueError(
748+
f"Connected component {treatment_component} containing the treatment node is not disjoint with "
749+
f"the outcome node set {outcome_node_set}"
750+
)
746751

747-
# Step 5: Update treatment node set to the connected component
752+
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
748753
treatment_node_set = treatment_component
749754

750-
# Step 6: Get neighbours of the treatment set
751-
neighbour_nodes = set()
752-
for node in treatment_node_set:
753-
neighbour_nodes.update(graph[node])
754-
neighbour_nodes.difference_update(treatment_node_set)
755+
# 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
756+
neighbour_nodes = {
757+
neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set
758+
}
755759

756-
# Step 7: If neighbours exist outside the outcome set, recurse
760+
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
757761
remaining = neighbour_nodes - outcome_node_set
758762
if remaining:
759-
chosen = sample(sorted(remaining), 1)[0] # Choose one deterministically (sorted) but randomly
760-
# Left branch: add to treatment set
763+
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
764+
chosen = sample(sorted(remaining), 1)
765+
# 7.2. Add this node to the treatment node set and recurse (left branch)
761766
yield from list_all_min_sep_opt(
762767
graph,
763768
treatment_node,
764769
outcome_node,
765-
treatment_node_set | {chosen},
770+
treatment_node_set.union(chosen),
766771
outcome_node_set,
767772
)
768-
# Right branch: add to outcome set
773+
# 7.3. Add this node to the outcome node set and recurse (right branch)
769774
yield from list_all_min_sep_opt(
770775
graph,
771776
treatment_node,
772777
outcome_node,
773778
treatment_node_set,
774-
outcome_node_set | {chosen},
779+
outcome_node_set.union(chosen),
775780
)
776781
else:
777782
# Step 8: All neighbours are in outcome set — we found a separator

0 commit comments

Comments
 (0)