Skip to content

Commit e168bde

Browse files
committed
Using subgraphs instead of copy seems promising
1 parent 16f868e commit e168bde

File tree

3 files changed

+236
-190
lines changed

3 files changed

+236
-190
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 1 addition & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,7 @@ def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str],
417417

418418
# Remove each variable one at a time and return false if constructive back-door criterion remains satisfied
419419
for variable in adjustment_set:
420-
smaller_adjustment_set = adjustment_set.copy()
421-
smaller_adjustment_set.remove(variable)
420+
smaller_adjustment_set = {a for a in adjustment_set if a != variable}
422421
if not smaller_adjustment_set: # Treat None as the empty set
423422
smaller_adjustment_set = set()
424423
if self.constructive_backdoor_criterion(
@@ -587,190 +586,3 @@ def to_dot_string(self) -> str:
587586

588587
def __str__(self):
589588
return f"Nodes: {self.nodes}\nEdges: {self.edges}"
590-
591-
592-
class OptimisedCausalDAG(CausalDAG):
593-
594-
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> Generator[set[str]]:
595-
"""Enumerate the minimal sets of variables that block all back-door paths between all pairs of treatments
596-
and outcomes.
597-
598-
This is an implementation of the Algorithm presented in Adjustment Criteria in Causal Diagrams: An
599-
Algorithmic Perspective, Textor and Liśkiewicz, 2012 and extended in Separators and adjustment sets in causal
600-
graphs: Complete criteria and an algorithmic framework, Zander et al., 2019. These works use the algorithm
601-
presented by Takata et al. in their work entitled: Space-optimal, backtracking algorithms to list the minimal
602-
vertex separators of a graph, 2013.
603-
604-
At a high-level, this algorithm proceeds as follows for a causal DAG G, set of treatments X, and set of
605-
outcomes Y:
606-
1). Transform G to a proper back-door graph G_pbd (remove the first edge from X on all proper causal paths).
607-
2). Transform G_pbd to the ancestor moral graph (G_pbd[An(X union Y)])^m.
608-
3). Apply Takata's algorithm to output all minimal X-Y separators in the graph.
609-
610-
:param treatments: A list of strings representing treatments.
611-
:param outcomes: A list of strings representing outcomes.
612-
:return: An iterator over sets of strings, each set representing a minimal adjustment set.
613-
"""
614-
615-
# Step 1: Build the proper back-door graph and its moralized ancestor graph
616-
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
617-
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
618-
moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph.graph)
619-
620-
# Step 2: Add artificial TREATMENT and OUTCOME nodes
621-
moralised_proper_backdoor_graph.add_edges_from([("TREATMENT", t) for t in treatments])
622-
moralised_proper_backdoor_graph.add_edges_from([("OUTCOME", y) for y in outcomes])
623-
624-
# Step 3: Remove treatment and outcome nodes from graph and connect neighbours
625-
treatment_neighbors = {
626-
node for t in treatments for node in moralised_proper_backdoor_graph[t] if node not in treatments
627-
}
628-
moralised_proper_backdoor_graph.add_edges_from(combinations(treatment_neighbors, 2))
629-
630-
outcome_neighbors = {
631-
node for o in outcomes for node in moralised_proper_backdoor_graph[o] if node not in outcomes
632-
}
633-
moralised_proper_backdoor_graph.add_edges_from(combinations(outcome_neighbors, 2))
634-
635-
# Step 4: Find all minimal separators of X^m and Y^m using Takata's algorithm for listing minimal separators
636-
sep_candidates = self.list_all_min_sep_opt(
637-
moralised_proper_backdoor_graph,
638-
"TREATMENT",
639-
"OUTCOME",
640-
{"TREATMENT"},
641-
set(moralised_proper_backdoor_graph["OUTCOME"]) | {"OUTCOME"},
642-
)
643-
return filter(
644-
lambda s: self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, s),
645-
sep_candidates,
646-
)
647-
648-
def constructive_backdoor_criterion(
649-
self,
650-
proper_backdoor_graph: CausalDAG,
651-
treatments: list[str],
652-
outcomes: list[str],
653-
covariates: list[str],
654-
) -> bool:
655-
"""A variation of Pearl's back-door criterion applied to a proper backdoor graph which enables more efficient
656-
computation of minimal adjustment sets for the effect of a set of treatments on a set of outcomes.
657-
658-
The constructive back-door criterion is satisfied for a causal DAG G, a set of treatments X, a set of outcomes
659-
Y, and a set of covariates Z, if:
660-
(1) No variable in Z is a descendant of any variable on a proper causal path between X and Y.
661-
(2) Z d-separates X and Y in the proper back-door graph relative to X and Y.
662-
663-
Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
664-
Zander et al., 2019, Definition 4, p.16)
665-
666-
:param proper_backdoor_graph: A proper back-door graph relative to the specified treatments and outcomes.
667-
:param treatments: A list of treatment variables that appear in the proper back-door graph.
668-
:param outcomes: A list of outcome variables that appear in the proper back-door graph.
669-
:param covariates: A list of variables that appear in the proper back-door graph that we will check against
670-
the constructive back-door criterion.
671-
:return: True or False, depending on whether the set of covariates satisfies the constructive back-door
672-
criterion.
673-
"""
674-
675-
# Condition (1): Covariates must not be descendants of any node on a proper causal path
676-
proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
677-
if proper_path_vars:
678-
# Collect all descendants including each proper causal path var itself
679-
descendents_of_proper_casual_paths = set(proper_path_vars).union(
680-
{node for var in proper_path_vars for node in nx.descendants(self.graph, var)}
681-
)
682-
683-
if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)):
684-
# Covariates intersect with disallowed descendants — fail condition 1
685-
logger.info(
686-
"Failed Condition 1: "
687-
"Z=%s **is** a descendant of variables on a proper causal path between X=%s and Y=%s.",
688-
covariates,
689-
treatments,
690-
outcomes,
691-
)
692-
return False
693-
694-
# Condition (2): Z must d-separate X and Y in the proper back-door graph
695-
if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)):
696-
logger.info(
697-
"Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.",
698-
covariates,
699-
treatments,
700-
outcomes,
701-
)
702-
return False
703-
704-
return True
705-
706-
def list_all_min_sep_opt(
707-
self,
708-
graph: nx.Graph,
709-
treatment_node: str,
710-
outcome_node: str,
711-
treatment_node_set: Set,
712-
outcome_node_set: Set,
713-
) -> Generator[Set, None, None]:
714-
"""A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
715-
716-
Reference: (Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, Ken Takata,
717-
2013, p.5, ListMinSep procedure).
718-
719-
:param graph: An undirected graph.
720-
:param treatment_node: The node corresponding to the treatment variable we wish to separate from the outcome.
721-
:param outcome_node: The node corresponding to the outcome variable we wish to separate from the treatment.
722-
:param treatment_node_set: Set of treatment nodes.
723-
:param outcome_node_set: Set of outcome nodes.
724-
:return: A generator of minimal-sized sets of variables which separate treatment and outcome in the undirected
725-
graph.
726-
"""
727-
# 1. Compute the close separator of the treatment set
728-
close_separator_set = close_separator(graph, treatment_node, outcome_node, treatment_node_set)
729-
730-
# 2. Use the close separator to separate the graph and obtain the connected components (connected sub-graphs)
731-
components_graph = graph.copy()
732-
components_graph.remove_nodes_from(close_separator_set)
733-
734-
# 3. Find the component containing the treatment node
735-
treatment_component = set()
736-
for component in nx.connected_components(components_graph):
737-
if treatment_node in component:
738-
treatment_component = component
739-
break
740-
741-
# 4. Confirm that the connected component containing the treatment node is disjoint with the outcome node set
742-
if treatment_component.intersection(outcome_node_set):
743-
return
744-
745-
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
746-
treatment_node_set = treatment_component
747-
748-
# 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
749-
neighbour_nodes = {
750-
neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set
751-
}
752-
753-
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
754-
remaining = neighbour_nodes - outcome_node_set
755-
if remaining:
756-
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
757-
chosen = sample(sorted(remaining), 1)
758-
# 7.2. Add this node to the treatment node set and recurse (left branch)
759-
yield from self.list_all_min_sep_opt(
760-
graph,
761-
treatment_node,
762-
outcome_node,
763-
treatment_node_set.union(chosen),
764-
outcome_node_set,
765-
)
766-
# 7.3. Add this node to the outcome node set and recurse (right branch)
767-
yield from self.list_all_min_sep_opt(
768-
graph,
769-
treatment_node,
770-
outcome_node,
771-
treatment_node_set,
772-
outcome_node_set.union(chosen),
773-
)
774-
else:
775-
# Step 8: All neighbours are in outcome set — we found a separator
776-
yield neighbour_nodes

0 commit comments

Comments
 (0)