@@ -26,18 +26,17 @@ class CausalDAG(nx.DiGraph):
26
26
ensures it is acyclic. A CausalDAG must be specified as a dot file.
27
27
"""
28
28
29
- def __init__ (self , dot_path : str = None , ignore_cycles : bool = False , ** attr ):
29
+ def __init__ (self , file_path : str = None , ignore_cycles : bool = False , ** attr ):
30
30
super ().__init__ (** attr )
31
31
self .ignore_cycles = ignore_cycles
32
- if dot_path :
33
- if dot_path .endswith (".dot" ):
34
- self . graph = nx .DiGraph (nx .nx_pydot .read_dot (dot_path ))
35
- elif dot_path .endswith (".xml" ):
36
- self . graph = nx .graphml .read_graphml (dot_path )
32
+ if file_path :
33
+ if file_path .endswith (".dot" ):
34
+ graph = nx .DiGraph (nx .nx_pydot .read_dot (file_path ))
35
+ elif file_path .endswith (".xml" ):
36
+ graph = nx .graphml .read_graphml (file_path )
37
37
else :
38
- raise ValueError (f"Unsupported file extension { dot_path } . We only support .dot and .xml files." )
39
- else :
40
- self .graph = nx .DiGraph ()
38
+ raise ValueError (f"Unsupported file extension { file_path } . We only support .dot and .xml files." )
39
+ self .update (graph )
41
40
42
41
if not self .is_acyclic ():
43
42
if ignore_cycles :
@@ -47,22 +46,6 @@ def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
47
46
else :
48
47
raise nx .HasACycle ("Invalid Causal DAG: contains a cycle." )
49
48
50
- @property
51
- def nodes (self ) -> list :
52
- """
53
- Get the nodes of the DAG.
54
- :returns: The nodes of the DAG.
55
- """
56
- return self .graph .nodes
57
-
58
- @property
59
- def edges (self ) -> list :
60
- """
61
- Get the edges of the DAG.
62
- :returns: The edges of the DAG.
63
- """
64
- return self .graph .edges
65
-
66
49
def close_separator (
67
50
self , graph : nx .Graph , treatment_node : Node , outcome_node : Node , treatment_node_set : set [Node ]
68
51
) -> set [Node ]:
@@ -173,11 +156,11 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
173
156
:return Boolean True if the three IV assumptions hold.
174
157
"""
175
158
# (i) Instrument is associated with treatment
176
- if nx .d_separated (self . graph , {instrument }, {treatment }, set ()):
159
+ if nx .d_separated (self , {instrument }, {treatment }, set ()):
177
160
raise ValueError (f"Instrument { instrument } is not associated with treatment { treatment } in the DAG" )
178
161
179
162
# (ii) Instrument does not affect outcome except through its potential effect on treatment
180
- if not all ((treatment in path for path in nx .all_simple_paths (self . graph , source = instrument , target = outcome ))):
163
+ if not all ((treatment in path for path in nx .all_simple_paths (self , source = instrument , target = outcome ))):
181
164
raise ValueError (
182
165
f"Instrument { instrument } affects the outcome { outcome } other than through the treatment { treatment } "
183
166
)
@@ -186,11 +169,9 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
186
169
187
170
for cause in self .nodes :
188
171
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
189
- outcome_paths = (
190
- list (nx .all_simple_paths (self .graph , source = cause , target = outcome )) if cause != outcome else []
191
- )
172
+ outcome_paths = list (nx .all_simple_paths (self , source = cause , target = outcome )) if cause != outcome else []
192
173
instrument_paths = (
193
- list (nx .all_simple_paths (self . graph , source = cause , target = instrument )) if cause != instrument else []
174
+ list (nx .all_simple_paths (self , source = cause , target = instrument )) if cause != instrument else []
194
175
)
195
176
if len (instrument_paths ) > 0 and len (outcome_paths ) > 0 :
196
177
raise ValueError (f"Instrument { instrument } and outcome { outcome } share common causes" )
@@ -204,15 +185,15 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
204
185
:param v_of_edge: To node
205
186
:param attr: Attributes
206
187
"""
207
- self .graph . add_edge (u_of_edge , v_of_edge , ** attr )
188
+ self .add_edge (u_of_edge , v_of_edge , ** attr )
208
189
if not self .is_acyclic ():
209
190
raise nx .HasACycle ("Invalid Causal DAG: contains a cycle." )
210
191
211
192
def cycle_nodes (self ) -> list :
212
193
"""Get the nodes involved in any cycles.
213
194
:return: A list containing all nodes involved in a cycle.
214
195
"""
215
- return [node for cycle in nx .simple_cycles (self . graph ) for node in cycle ]
196
+ return [node for cycle in nx .simple_cycles (self ) for node in cycle ]
216
197
217
198
def is_acyclic (self ) -> bool :
218
199
"""Checks if the graph is acyclic.
@@ -235,16 +216,14 @@ def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str])
235
216
:param outcomes: A list of outcomes.
236
217
:return: A CausalDAG corresponding to the proper back-door graph.
237
218
"""
238
- for var in treatments + outcomes :
239
- if var not in self .nodes :
240
- raise IndexError (f"{ var } not a node in Causal DAG.\n Valid nodes are{ self .nodes } ." )
241
-
242
- proper_backdoor_graph = self .copy ()
243
- nodes_on_proper_causal_path = proper_backdoor_graph .proper_causal_pathway (treatments , outcomes )
244
- edges_to_remove = [
245
- (u , v ) for (u , v ) in proper_backdoor_graph .graph .out_edges (treatments ) if v in nodes_on_proper_causal_path
246
- ]
247
- proper_backdoor_graph .graph .remove_edges_from (edges_to_remove )
219
+ assert set (treatments + outcomes ).issubset (
220
+ set (self .nodes )
221
+ ), f"Nodes { set (treatments + outcomes ).difference (set (self .nodes ))} not in causal DAG"
222
+
223
+ nodes_on_proper_causal_path = self .proper_causal_pathway (treatments , outcomes )
224
+ proper_backdoor_graph = CausalDAG ()
225
+ edges_to_remove = {(u , v ) for (u , v ) in self .out_edges (treatments ) if v in nodes_on_proper_causal_path }
226
+ proper_backdoor_graph .add_edges_from (e for e in self .edges if e not in edges_to_remove )
248
227
return proper_backdoor_graph
249
228
250
229
def get_ancestor_graph (self , treatments : list [str ], outcomes : list [str ]) -> CausalDAG :
@@ -261,17 +240,10 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
261
240
:param outcomes: A list of outcome variables to include in the ancestral graph (and their ancestors).
262
241
:return: An ancestral graph relative to the set of variables X union Y.
263
242
"""
264
- ancestor_graph = self .copy ()
265
- treatment_ancestors = set .union (
266
- * [nx .ancestors (ancestor_graph .graph , treatment ).union ({treatment }) for treatment in treatments ]
267
- )
268
- outcome_ancestors = set .union (
269
- * [nx .ancestors (ancestor_graph .graph , outcome ).union ({outcome }) for outcome in outcomes ]
270
- )
271
- variables_to_keep = treatment_ancestors .union (outcome_ancestors )
272
- variables_to_remove = set (self .nodes ).difference (variables_to_keep )
273
- ancestor_graph .graph .remove_nodes_from (variables_to_remove )
274
- return ancestor_graph
243
+ variables_to_keep = {
244
+ ancestor for var in treatments + outcomes for ancestor in nx .ancestors (self , var ).union ({var })
245
+ }
246
+ return self .subgraph (variables_to_keep )
275
247
276
248
def get_indirect_graph (self , treatments : list [str ], outcomes : list [str ]) -> CausalDAG :
277
249
"""
@@ -283,14 +255,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
283
255
:return: The indirect graph with edges pointing from X to Y removed.
284
256
:rtype: CausalDAG
285
257
"""
286
- gback = self .copy ()
287
- ee = []
288
- for s in treatments :
289
- for t in outcomes :
290
- if (s , t ) in gback .edges :
291
- ee .append ((s , t ))
292
- for v1 , v2 in ee :
293
- gback .graph .remove_edge (v1 , v2 )
258
+ ee = {(s , t ) for s in treatments for t in outcomes if (s , t ) in self .edges }
259
+ gback = CausalDAG ()
260
+ gback .add_edges_from (filter (lambda x : x not in ee , self .edges ))
294
261
return gback
295
262
296
263
def direct_effect_adjustment_sets (
@@ -319,7 +286,7 @@ def direct_effect_adjustment_sets(
319
286
320
287
indirect_graph = self .get_indirect_graph (treatments , outcomes )
321
288
ancestor_graph = indirect_graph .get_ancestor_graph (treatments , outcomes )
322
- gam = nx .moral_graph (ancestor_graph . graph )
289
+ gam = nx .moral_graph (ancestor_graph )
323
290
324
291
edges_to_add = [("TREATMENT" , treatment ) for treatment in treatments ]
325
292
edges_to_add += [("OUTCOME" , outcome ) for outcome in outcomes ]
@@ -354,7 +321,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
354
321
# Step 1: Build the proper back-door graph and its moralized ancestor graph
355
322
proper_backdoor_graph = self .get_proper_backdoor_graph (treatments , outcomes )
356
323
ancestor_proper_backdoor_graph = proper_backdoor_graph .get_ancestor_graph (treatments , outcomes )
357
- moralised_proper_backdoor_graph = nx .moral_graph (ancestor_proper_backdoor_graph . graph )
324
+ moralised_proper_backdoor_graph = nx .moral_graph (ancestor_proper_backdoor_graph )
358
325
359
326
# Step 2: Add artificial TREATMENT and OUTCOME nodes
360
327
moralised_proper_backdoor_graph .add_edges_from ([("TREATMENT" , t ) for t in treatments ])
@@ -453,7 +420,7 @@ def constructive_backdoor_criterion(
453
420
if proper_path_vars :
454
421
# Collect all descendants including each proper causal path var itself
455
422
descendents_of_proper_casual_paths = set (proper_path_vars ).union (
456
- {node for var in proper_path_vars for node in nx .descendants (self . graph , var )}
423
+ {node for var in proper_path_vars for node in nx .descendants (self , var )}
457
424
)
458
425
459
426
if not set (covariates ).issubset (set (self .nodes ).difference (descendents_of_proper_casual_paths )):
@@ -468,7 +435,7 @@ def constructive_backdoor_criterion(
468
435
return False
469
436
470
437
# Condition (2): Z must d-separate X and Y in the proper back-door graph
471
- if not nx .d_separated (proper_backdoor_graph . graph , set (treatments ), set (outcomes ), set (covariates )):
438
+ if not nx .d_separated (proper_backdoor_graph , set (treatments ), set (outcomes ), set (covariates )):
472
439
logger .info (
473
440
"Failed Condition 2: Z=%s **does not** d-separate X=%s and Y=%s in the proper back-door graph." ,
474
441
covariates ,
@@ -492,7 +459,7 @@ def proper_causal_pathway(self, treatments: list[str], outcomes: list[str]) -> l
492
459
treatments and outcomes.
493
460
"""
494
461
treatments_descendants = set .union (
495
- * [nx .descendants (self . graph , treatment ).union ({treatment }) for treatment in treatments ]
462
+ * [nx .descendants (self , treatment ).union ({treatment }) for treatment in treatments ]
496
463
)
497
464
treatments_descendants_without_treatments = set (treatments_descendants ).difference (treatments )
498
465
backdoor_graph = self .get_backdoor_graph (set (treatments ))
@@ -507,9 +474,9 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG:
507
474
:param treatments: The set of treatments whose outgoing edges will be deleted.
508
475
:return: A back-door graph corresponding to the given causal DAG and set of treatments.
509
476
"""
510
- outgoing_edges = self .graph . out_edges (treatments )
511
- backdoor_graph = self . graph . copy ()
512
- backdoor_graph .remove_edges_from ( outgoing_edges )
477
+ outgoing_edges = self .out_edges (treatments )
478
+ backdoor_graph = CausalDAG ()
479
+ backdoor_graph .add_edges_from ( filter ( lambda x : x not in outgoing_edges , self . edges ) )
513
480
return backdoor_graph
514
481
515
482
def depends_on_outputs (self , node : Node , scenario : Scenario ) -> bool :
@@ -526,7 +493,7 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
526
493
"""
527
494
if isinstance (scenario .variables [node ], Output ):
528
495
return True
529
- return any ((self .depends_on_outputs (n , scenario ) for n in self .graph . predecessors (node )))
496
+ return any ((self .depends_on_outputs (n , scenario ) for n in self .predecessors (node )))
530
497
531
498
@staticmethod
532
499
def remove_hidden_adjustment_sets (minimal_adjustment_sets : list [str ], scenario : Scenario ):
@@ -546,7 +513,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
546
513
estimate as opposed to a purely associational estimate.
547
514
"""
548
515
if self .ignore_cycles :
549
- return set (self .graph . predecessors (base_test_case .treatment_variable .name ))
516
+ return set (self .predecessors (base_test_case .treatment_variable .name ))
550
517
minimal_adjustment_sets = []
551
518
if base_test_case .effect == "total" :
552
519
minimal_adjustment_sets = self .enumerate_minimal_adjustment_sets (
0 commit comments