|
4 | 4 |
|
5 | 5 | import logging
|
6 | 6 | from itertools import combinations
|
7 |
| -from random import sample |
8 | 7 | from typing import Union, Generator, Set
|
9 | 8 |
|
10 | 9 | import networkx as nx
|
@@ -64,14 +63,12 @@ def close_separator(
|
64 | 63 | :param treatment_node_set: The set of variables containing the treatment node ({treatment_node}).
|
65 | 64 | :return: A treatment_node-outcome_node separator whose vertices are adjacent to those in treatments.
|
66 | 65 | """
|
67 |
| - treatment_neighbours = set.union(*[set(nx.neighbors(graph, treatment)) for treatment in treatment_node_set]) |
| 66 | + treatment_neighbours = {x for treatment in treatment_node_set for x in graph[treatment]} |
68 | 67 | components_graph = graph.subgraph(set(graph.nodes) - treatment_neighbours)
|
69 | 68 | graph_components = nx.connected_components(components_graph)
|
70 | 69 | for component in graph_components:
|
71 | 70 | if outcome_node in component:
|
72 |
| - neighbours_of_variables_in_component = set.union( |
73 |
| - *[set(nx.neighbors(graph, variable)) for variable in component] |
74 |
| - ) |
| 71 | + neighbours_of_variables_in_component = {x for variable in component for x in graph[variable]} |
75 | 72 | # For this algorithm, the neighbours of a node do not include the node itself
|
76 | 73 | neighbours_of_variables_in_component = neighbours_of_variables_in_component.difference(component)
|
77 | 74 | return neighbours_of_variables_in_component
|
@@ -115,37 +112,37 @@ def list_all_min_sep(
|
115 | 112 | if treatment_component.intersection(outcome_node_set):
|
116 | 113 | return
|
117 | 114 |
|
118 |
| - # 5. Update the treatment node set to the set of nodes in the connected component containing the treatment node |
119 |
| - treatment_node_set = treatment_component |
120 |
| - |
121 |
| - # 6. Obtain the neighbours of the new treatment node set (this excludes the treatment nodes themselves) |
| 115 | + # 5. Obtain the neighbours of the treatment component (this excludes the nodes in the component itself) |
122 | 116 | neighbour_nodes = {
|
123 |
| - neighbour for node in treatment_node_set for neighbour in graph[node] if neighbour not in treatment_node_set |
| 117 | + neighbour |
| 118 | + for node in treatment_component |
| 119 | + for neighbour in graph[node] |
| 120 | + if neighbour not in treatment_component |
124 | 121 | }
|
125 | 122 |
|
126 |
| - # 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set |
| 123 | + # 6. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set |
127 | 124 | remaining = neighbour_nodes - outcome_node_set
|
128 | 125 | if remaining:
|
129 |
| - # 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set |
130 |
| - chosen = sample(sorted(remaining), 1) |
131 |
| - # 7.2. Add this node to the treatment node set and recurse (left branch) |
| 126 | + # 6.1. If so, pick an arbitrary node (no random sampling) from the treatment nodes' neighbours not in the outcome node set |
| 127 | + chosen = {next(iter(remaining))} |
| 128 | + # 6.2. Add this node to the treatment node set and recurse (left branch) |
132 | 129 | yield from self.list_all_min_sep(
|
133 | 130 | graph,
|
134 | 131 | treatment_node,
|
135 | 132 | outcome_node,
|
136 |
| - treatment_node_set.union(chosen), |
| 133 | + treatment_component.union(chosen), |
137 | 134 | outcome_node_set,
|
138 | 135 | )
|
139 |
| - # 7.3. Add this node to the outcome node set and recurse (right branch) |
| 136 | + # 6.3. Add this node to the outcome node set and recurse (right branch) |
140 | 137 | yield from self.list_all_min_sep(
|
141 | 138 | graph,
|
142 | 139 | treatment_node,
|
143 | 140 | outcome_node,
|
144 |
| - treatment_node_set, |
| 141 | + treatment_component, |
145 | 142 | outcome_node_set.union(chosen),
|
146 | 143 | )
|
147 | 144 | else:
|
148 |
| - # Step 8: All neighbours are in outcome set — we found a separator |
| 145 | + # Step 7: All neighbours are in outcome set — we found a separator |
149 | 146 | yield neighbour_nodes
|
150 | 147 |
|
151 | 148 | def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
|
@@ -185,7 +182,7 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
|
185 | 182 | :param v_of_edge: To node
|
186 | 183 | :param attr: Attributes
|
187 | 184 | """
|
188 |
| - self.add_edge(u_of_edge, v_of_edge, **attr) |
| 185 | + super().add_edge(u_of_edge, v_of_edge, **attr) |
189 | 186 | if not self.is_acyclic():
|
190 | 187 | raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
|
191 | 188 |
|
@@ -240,9 +237,8 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
|
240 | 237 | :param outcomes: A list of outcome variables to include in the ancestral graph (and their ancestors).
|
241 | 238 | :return: An ancestral graph relative to the set of variables X union Y.
|
242 | 239 | """
|
243 |
| - variables_to_keep = { |
244 |
| - ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var).union({var}) |
245 |
| - } |
| 240 | + variables_to_keep = set(treatments + outcomes) |
| 241 | + variables_to_keep.update({ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var)}) |
246 | 242 | return self.subgraph(variables_to_keep)
|
247 | 243 |
|
248 | 244 | def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
|
@@ -419,7 +415,8 @@ def constructive_backdoor_criterion(
|
419 | 415 | proper_path_vars = self.proper_causal_pathway(treatments, outcomes)
|
420 | 416 | if proper_path_vars:
|
421 | 417 | # Collect all descendants including each proper causal path var itself
|
422 |
| - descendents_of_proper_casual_paths = set(proper_path_vars).union( |
| 418 | + descendents_of_proper_casual_paths = set(proper_path_vars) |
| 419 | + descendents_of_proper_casual_paths.update( |
423 | 420 | {node for var in proper_path_vars for node in nx.descendants(self, var)}
|
424 | 421 | )
|
425 | 422 |
|
@@ -458,12 +455,12 @@ def proper_causal_pathway(self, treatments: list[str], outcomes: list[str]) -> l
|
458 | 455 | :return vars_on_proper_causal_pathway: Return a list of the variables on the proper causal pathway between
|
459 | 456 | treatments and outcomes.
|
460 | 457 | """
|
461 |
| - treatments_descendants = set.union( |
462 |
| - *[nx.descendants(self, treatment).union({treatment}) for treatment in treatments] |
463 |
| - ) |
464 |
| - treatments_descendants_without_treatments = set(treatments_descendants).difference(treatments) |
| 458 | + treatments_descendants_without_treatments = { |
| 459 | + x for treatment in treatments for x in nx.descendants(self, treatment) if x not in treatments |
| 460 | + } |
465 | 461 | backdoor_graph = self.get_backdoor_graph(set(treatments))
|
466 |
| - outcome_ancestors = set.union(*[nx.ancestors(backdoor_graph, outcome).union({outcome}) for outcome in outcomes]) |
| 462 | + outcome_ancestors = set(outcomes) |
| 463 | + outcome_ancestors.update({x for outcome in outcomes for x in nx.ancestors(backdoor_graph, outcome)}) |
467 | 464 | nodes_on_proper_causal_paths = treatments_descendants_without_treatments.intersection(outcome_ancestors)
|
468 | 465 | return nodes_on_proper_causal_paths
|
469 | 466 |
|
@@ -512,6 +509,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
|
512 | 509 | :return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
|
513 | 510 | estimate as opposed to a purely associational estimate.
|
514 | 511 | """
|
| 512 | + # Naive method to guarantee termination when we have cycles |
515 | 513 | if self.ignore_cycles:
|
516 | 514 | return set(self.predecessors(base_test_case.treatment_variable.name))
|
517 | 515 | minimal_adjustment_sets = []
|
|
0 commit comments