Skip to content

Commit 5d53f9a

Browse files
committed
Removed all instances of copy
1 parent 4218038 commit 5d53f9a

File tree

1 file changed

+39
-72
lines changed

1 file changed

+39
-72
lines changed

causal_testing/specification/optimised_causal_dag.py

Lines changed: 39 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,17 @@ class CausalDAG(nx.DiGraph):
2626
ensures it is acyclic. A CausalDAG must be specified as a dot file.
2727
"""
2828

29-
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
29+
def __init__(self, file_path: str = None, ignore_cycles: bool = False, **attr):
3030
super().__init__(**attr)
3131
self.ignore_cycles = ignore_cycles
32-
if dot_path:
33-
if dot_path.endswith(".dot"):
34-
self.graph = nx.DiGraph(nx.nx_pydot.read_dot(dot_path))
35-
elif dot_path.endswith(".xml"):
36-
self.graph = nx.graphml.read_graphml(dot_path)
32+
if file_path:
33+
if file_path.endswith(".dot"):
34+
graph = nx.DiGraph(nx.nx_pydot.read_dot(file_path))
35+
elif file_path.endswith(".xml"):
36+
graph = nx.graphml.read_graphml(file_path)
3737
else:
38-
raise ValueError(f"Unsupported file extension {dot_path}. We only support .dot and .xml files.")
39-
else:
40-
self.graph = nx.DiGraph()
38+
raise ValueError(f"Unsupported file extension {file_path}. We only support .dot and .xml files.")
39+
self.update(graph)
4140

4241
if not self.is_acyclic():
4342
if ignore_cycles:
@@ -47,22 +46,6 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
4746
else:
4847
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
4948

50-
@property
51-
def nodes(self) -> list:
52-
"""
53-
Get the nodes of the DAG.
54-
:returns: The nodes of the DAG.
55-
"""
56-
return self.graph.nodes
57-
58-
@property
59-
def edges(self) -> list:
60-
"""
61-
Get the edges of the DAG.
62-
:returns: The edges of the DAG.
63-
"""
64-
return self.graph.edges
65-
6649
def close_separator(
6750
self, graph: nx.Graph, treatment_node: Node, outcome_node: Node, treatment_node_set: set[Node]
6851
) -> set[Node]:
@@ -173,11 +156,11 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
173156
:return Boolean True if the three IV assumptions hold.
174157
"""
175158
# (i) Instrument is associated with treatment
176-
if nx.d_separated(self.graph, {instrument}, {treatment}, set()):
159+
if nx.d_separated(self, {instrument}, {treatment}, set()):
177160
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
178161

179162
# (ii) Instrument does not affect outcome except through its potential effect on treatment
180-
if not all((treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome))):
163+
if not all((treatment in path for path in nx.all_simple_paths(self, source=instrument, target=outcome))):
181164
raise ValueError(
182165
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
183166
)
@@ -186,11 +169,9 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
186169

187170
for cause in self.nodes:
188171
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
189-
outcome_paths = (
190-
list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else []
191-
)
172+
outcome_paths = list(nx.all_simple_paths(self, source=cause, target=outcome)) if cause != outcome else []
192173
instrument_paths = (
193-
list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else []
174+
list(nx.all_simple_paths(self, source=cause, target=instrument)) if cause != instrument else []
194175
)
195176
if len(instrument_paths) > 0 and len(outcome_paths) > 0:
196177
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
@@ -204,15 +185,15 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
204185
:param v_of_edge: To node
205186
:param attr: Attributes
206187
"""
207-
self.graph.add_edge(u_of_edge, v_of_edge, **attr)
188+
self.add_edge(u_of_edge, v_of_edge, **attr)
208189
if not self.is_acyclic():
209190
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
210191

211192
def cycle_nodes(self) -> list:
212193
"""Get the nodes involved in any cycles.
213194
:return: A list containing all nodes involved in a cycle.
214195
"""
215-
return [node for cycle in nx.simple_cycles(self.graph) for node in cycle]
196+
return [node for cycle in nx.simple_cycles(self) for node in cycle]
216197

217198
def is_acyclic(self) -> bool:
218199
"""Checks if the graph is acyclic.
@@ -235,16 +216,14 @@ def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str])
235216
:param outcomes: A list of outcomes.
236217
:return: A CausalDAG corresponding to the proper back-door graph.
237218
"""
238-
for var in treatments + outcomes:
239-
if var not in self.nodes:
240-
raise IndexError(f"{var} not a node in Causal DAG.\nValid nodes are{self.nodes}.")
241-
242-
proper_backdoor_graph = self.copy()
243-
nodes_on_proper_causal_path = proper_backdoor_graph.proper_causal_pathway(treatments, outcomes)
244-
edges_to_remove = [
245-
(u, v) for (u, v) in proper_backdoor_graph.graph.out_edges(treatments) if v in nodes_on_proper_causal_path
246-
]
247-
proper_backdoor_graph.graph.remove_edges_from(edges_to_remove)
219+
assert set(treatments + outcomes).issubset(
220+
set(self.nodes)
221+
), f"Nodes {set(treatments+outcomes).difference(set(self.nodes))} not in causal DAG"
222+
223+
nodes_on_proper_causal_path = self.proper_causal_pathway(treatments, outcomes)
224+
proper_backdoor_graph = CausalDAG()
225+
edges_to_remove = {(u, v) for (u, v) in self.out_edges(treatments) if v in nodes_on_proper_causal_path}
226+
proper_backdoor_graph.add_edges_from(e for e in self.edges if e not in edges_to_remove)
248227
return proper_backdoor_graph
249228

250229
def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
@@ -261,17 +240,10 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
261240
:param outcomes: A list of outcome variables to include in the ancestral graph (and their ancestors).
262241
:return: An ancestral graph relative to the set of variables X union Y.
263242
"""
264-
ancestor_graph = self.copy()
265-
treatment_ancestors = set.union(
266-
*[nx.ancestors(ancestor_graph.graph, treatment).union({treatment}) for treatment in treatments]
267-
)
268-
outcome_ancestors = set.union(
269-
*[nx.ancestors(ancestor_graph.graph, outcome).union({outcome}) for outcome in outcomes]
270-
)
271-
variables_to_keep = treatment_ancestors.union(outcome_ancestors)
272-
variables_to_remove = set(self.nodes).difference(variables_to_keep)
273-
ancestor_graph.graph.remove_nodes_from(variables_to_remove)
274-
return ancestor_graph
243+
variables_to_keep = {
244+
ancestor for var in treatments + outcomes for ancestor in nx.ancestors(self, var).union({var})
245+
}
246+
return self.subgraph(variables_to_keep)
275247

276248
def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
277249
"""
@@ -283,14 +255,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
283255
:return: The indirect graph with edges pointing from X to Y removed.
284256
:rtype: CausalDAG
285257
"""
286-
gback = self.copy()
287-
ee = []
288-
for s in treatments:
289-
for t in outcomes:
290-
if (s, t) in gback.edges:
291-
ee.append((s, t))
292-
for v1, v2 in ee:
293-
gback.graph.remove_edge(v1, v2)
258+
ee = {(s, t) for s in treatments for t in outcomes if (s, t) in self.edges}
259+
gback = CausalDAG()
260+
gback.add_edges_from(filter(lambda x: x not in ee, self.edges))
294261
return gback
295262

296263
def direct_effect_adjustment_sets(
@@ -319,7 +286,7 @@ def direct_effect_adjustment_sets(
319286

320287
indirect_graph = self.get_indirect_graph(treatments, outcomes)
321288
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
322-
gam = nx.moral_graph(ancestor_graph.graph)
289+
gam = nx.moral_graph(ancestor_graph)
323290

324291
edges_to_add = [("TREATMENT", treatment) for treatment in treatments]
325292
edges_to_add += [("OUTCOME", outcome) for outcome in outcomes]
@@ -354,7 +321,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
354321
# Step 1: Build the proper back-door graph and its moralized ancestor graph
355322
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
356323
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
357-
moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph.graph)
324+
moralised_proper_backdoor_graph = nx.moral_graph(ancestor_proper_backdoor_graph)
358325

359326
# Step 2: Add artificial TREATMENT and OUTCOME nodes
360327
moralised_proper_backdoor_graph.add_edges_from([("TREATMENT", t) for t in treatments])
@@ -453,7 +420,7 @@ def constructive_backdoor_criterion(
453420
if proper_path_vars:
454421
# Collect all descendants including each proper causal path var itself
455422
descendents_of_proper_casual_paths = set(proper_path_vars).union(
456-
{node for var in proper_path_vars for node in nx.descendants(self.graph, var)}
423+
{node for var in proper_path_vars for node in nx.descendants(self, var)}
457424
)
458425

459426
if not set(covariates).issubset(set(self.nodes).difference(descendents_of_proper_casual_paths)):
@@ -468,7 +435,7 @@ def constructive_backdoor_criterion(
468435
return False
469436

470437
# Condition (2): Z must d-separate X and Y in the proper back-door graph
471-
if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)):
438+
if not nx.d_separated(proper_backdoor_graph, set(treatments), set(outcomes), set(covariates)):
472439
logger.info(
473440
"Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph.",
474441
covariates,
@@ -492,7 +459,7 @@ def proper_causal_pathway(self, treatments: list[str], outcomes: list[str]) -> l
492459
treatments and outcomes.
493460
"""
494461
treatments_descendants = set.union(
495-
*[nx.descendants(self.graph, treatment).union({treatment}) for treatment in treatments]
462+
*[nx.descendants(self, treatment).union({treatment}) for treatment in treatments]
496463
)
497464
treatments_descendants_without_treatments = set(treatments_descendants).difference(treatments)
498465
backdoor_graph = self.get_backdoor_graph(set(treatments))
@@ -507,9 +474,9 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG:
507474
:param treatments: The set of treatments whose outgoing edges will be deleted.
508475
:return: A back-door graph corresponding to the given causal DAG and set of treatments.
509476
"""
510-
outgoing_edges = self.graph.out_edges(treatments)
511-
backdoor_graph = self.graph.copy()
512-
backdoor_graph.remove_edges_from(outgoing_edges)
477+
outgoing_edges = self.out_edges(treatments)
478+
backdoor_graph = CausalDAG()
479+
backdoor_graph.add_edges_from(filter(lambda x: x not in outgoing_edges, self.edges))
513480
return backdoor_graph
514481

515482
def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
@@ -526,7 +493,7 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
526493
"""
527494
if isinstance(scenario.variables[node], Output):
528495
return True
529-
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))
496+
return any((self.depends_on_outputs(n, scenario) for n in self.predecessors(node)))
530497

531498
@staticmethod
532499
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
@@ -546,7 +513,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
546513
estimate as opposed to a purely associational estimate.
547514
"""
548515
if self.ignore_cycles:
549-
return set(self.graph.predecessors(base_test_case.treatment_variable.name))
516+
return set(self.predecessors(base_test_case.treatment_variable.name))
550517
minimal_adjustment_sets = []
551518
if base_test_case.effect == "total":
552519
minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets(

0 commit comments

Comments
 (0)