Skip to content

Commit edc8455

Browse files
committed
Finished optimising, tests pass
1 parent bfca023 commit edc8455

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

causal_testing/specification/optimised_causal_dag.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import logging
66
from itertools import combinations
7-
from random import sample
87
from typing import Union, Generator, Set
98

109
import networkx as nx
@@ -64,14 +63,12 @@ def close_separator(
6463
:param treatment_node_set: The set of variables containing the treatment node ({treatment_node}).
6564
:return: A treatment_node-outcome_node separator whose vertices are adjacent to those in treatments.
6665
"""
67-
treatment_neighbours = set.union(*[set(nx.neighbors(graph, treatment)) for treatment in treatment_node_set])
66+
treatment_neighbours = {x for treatment in treatment_node_set for x in graph[treatment]}
6867
components_graph = graph.subgraph(set(graph.nodes) - treatment_neighbours)
6968
graph_components = nx.connected_components(components_graph)
7069
for component in graph_components:
7170
if outcome_node in component:
72-
neighbours_of_variables_in_component = set.union(
73-
*[set(nx.neighbors(graph, variable)) for variable in component]
74-
)
71+
neighbours_of_variables_in_component = {x for variable in component for x in graph[variable]}
7572
# For this algorithm, the neighbours of a node do not include the node itself
7673
neighbours_of_variables_in_component = neighbours_of_variables_in_component.difference(component)
7774
return neighbours_of_variables_in_component
@@ -115,37 +112,37 @@ def list_all_min_sep(
115112
if treatment_component.intersection(outcome_node_set):
116113
return
117114

118-
# 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node
119-
treatment_node_set = treatment_component
120-
121-
# 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves)
115+
# 5. Obtain the neighbours of the treatment node's connected component (this excludes the component's own nodes)
122116
neighbour_nodes = {
123-
neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set
117+
neighbour
118+
for node in treatment_component
119+
for neighbour in graph[node]
120+
if neighbour not in treatment_component
124121
}
125122

126-
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
123+
# 6. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
127124
remaining = neighbour_nodes - outcome_node_set
128125
if remaining:
129-
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
130-
chosen = sample(sorted(remaining), 1)
131-
# 7.2. Add this node to the treatment node set and recurse (left branch)
126+
# 6.1. If so, pick an arbitrary node (the first one yielded by iterating the set; no longer a random sample) from the treatment nodes' neighbours not in the outcome node set
127+
chosen = {next(iter(remaining))}
128+
# 6.2. Add this node to the treatment node set and recurse (left branch)
132129
yield from self.list_all_min_sep(
133130
graph,
134131
treatment_node,
135132
outcome_node,
136-
treatment_node_set.union(chosen),
133+
treatment_component.union(chosen),
137134
outcome_node_set,
138135
)
139-
# 7.3. Add this node to the outcome node set and recurse (right branch)
136+
# 6.3. Add this node to the outcome node set and recurse (right branch)
140137
yield from self.list_all_min_sep(
141138
graph,
142139
treatment_node,
143140
outcome_node,
144-
treatment_node_set,
141+
treatment_component,
145142
outcome_node_set.union(chosen),
146143
)
147144
else:
148-
# Step 8: All neighbours are in outcome set — we found a separator
145+
# Step 7: All neighbours are in outcome set — we found a separator
149146
yield neighbour_nodes
150147

151148
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
@@ -185,7 +182,7 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
185182
:param v_of_edge: To node
186183
:param attr: Attributes
187184
"""
188-
self.add_edge(u_of_edge, v_of_edge, **attr)
185+
super().add_edge(u_of_edge, v_of_edge, **attr)
189186
if not self.is_acyclic():
190187
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
191188

@@ -240,9 +237,8 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
240237
:param outcomes: A list of outcome variables to include in the ancestral graph (and their ancestors).
241238
:return: An ancestral graph relative to the set of variables X union Y.
242239
"""
243-
variables_to_keep = {
244-
ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var).union({var})
245-
}
240+
variables_to_keep = set(treatments + outcomes)
241+
variables_to_keep.update({ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var)})
246242
return self.subgraph(variables_to_keep)
247243

248244
def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
@@ -419,7 +415,8 @@ def constructive_backdoor_criterion(
419415
proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
420416
if proper_path_vars:
421417
# Collect all descendants including each proper causal path var itself
422-
descendents_of_proper_casual_paths = set(proper_path_vars).union(
418+
descendents_of_proper_casual_paths = set(proper_path_vars)
419+
descendents_of_proper_casual_paths.update(
423420
{node for var in proper_path_vars for node in nx.descendants(self, var)}
424421
)
425422

@@ -458,12 +455,12 @@ def proper_causal_pathway(self, treatments: list[str], outcomes: list[str]) -> l
458455
:return vars_on_proper_causal_pathway: Return a list of the variables on the proper causal pathway between
459456
treatments and outcomes.
460457
"""
461-
treatments_descendants = set.union(
462-
*[nx.descendants(self, treatment).union({treatment}) for treatment in treatments]
463-
)
464-
treatments_descendants_without_treatments = set(treatments_descendants).difference(treatments)
458+
treatments_descendants_without_treatments = {
459+
x for treatment in treatments for x in nx.descendants(self, treatment) if x not in treatments
460+
}
465461
backdoor_graph = self.get_backdoor_graph(set(treatments))
466-
outcome_ancestors = set.union(*[nx.ancestors(backdoor_graph, outcome).union({outcome}) for outcome in outcomes])
462+
outcome_ancestors = set(outcomes)
463+
outcome_ancestors.update({x for outcome in outcomes for x in nx.ancestors(backdoor_graph, outcome)})
467464
nodes_on_proper_causal_paths = treatments_descendants_without_treatments.intersection(outcome_ancestors)
468465
return nodes_on_proper_causal_paths
469466

@@ -512,6 +509,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
512509
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
513510
estimate as opposed to a purely associational estimate.
514511
"""
512+
# Naive method to guarantee termination when we have cycles
515513
if self.ignore_cycles:
516514
return set(self.predecessors(base_test_case.treatment_variable.name))
517515
minimal_adjustment_sets = []

0 commit comments

Comments
 (0)