@@ -160,9 +160,9 @@ def is_acyclic(self) -> bool:
160
160
def get_proper_backdoor_graph (self , treatments : list [str ], outcomes : list [str ]) -> CausalDAG :
161
161
"""Convert the causal DAG to a proper back-door graph.
162
162
163
- A proper back-door graph of a causal DAG is obtained by
164
- removing the first edge of every proper causal path from treatments to outcomes. A proper causal path from
165
- X to Y is a path of directed edges that starts from X and ends in Y.
163
+ A proper back-door graph of a causal DAG is obtained by removing the first edge of every proper causal path from
164
+ treatments to outcomes. A proper causal path from X to Y is a path of directed edges that starts from X and ends
165
+ in Y.
166
166
167
167
Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
168
168
Zander et al., 2019, Definition 3, p.15)
@@ -311,7 +311,10 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
311
311
outcome_node_set ,
312
312
)
313
313
)
314
- return minimum_adjustment_sets
314
+ valid_minimum_adjustment_sets = [adj for adj in minimum_adjustment_sets
315
+ if self .constructive_backdoor_criterion (proper_backdoor_graph , treatments ,
316
+ outcomes , adj )]
317
+ return valid_minimum_adjustment_sets
315
318
316
319
def adjustment_set_is_minimal (self , treatments : list [str ], outcomes : list [str ], adjustment_set : set [str ]) -> bool :
317
320
"""Given a list of treatments X, a list of outcomes Y, and an adjustment set Z, determine whether Z is the
@@ -375,22 +378,23 @@ def constructive_backdoor_criterion(
375
378
"""
376
379
# Condition (1)
377
380
proper_causal_path_vars = self .proper_causal_pathway (treatments , outcomes )
378
- descendents_of_proper_casual_paths = set .union (
379
- * [
380
- set .union (
381
- nx .descendants (self .graph , proper_causal_path_var ),
382
- {proper_causal_path_var },
383
- )
384
- for proper_causal_path_var in proper_causal_path_vars
385
- ]
386
- )
387
-
388
- if not set (covariates ).issubset (set (self .graph .nodes ).difference (descendents_of_proper_casual_paths )):
389
- logger .info (
390
- f"Failed Condition 1: Z={ covariates } **is** a descendent of some variable on a proper causal "
391
- f"path between X={ treatments } and Y={ outcomes } ."
381
+ if proper_causal_path_vars :
382
+ descendents_of_proper_casual_paths = set .union (
383
+ * [
384
+ set .union (
385
+ nx .descendants (self .graph , proper_causal_path_var ),
386
+ {proper_causal_path_var },
387
+ )
388
+ for proper_causal_path_var in proper_causal_path_vars
389
+ ]
392
390
)
393
- return False
391
+
392
+ if not set (covariates ).issubset (set (self .graph .nodes ).difference (descendents_of_proper_casual_paths )):
393
+ logger .info (
394
+ f"Failed Condition 1: Z={ covariates } **is** a descendent of some variable on a proper causal "
395
+ f"path between X={ treatments } and Y={ outcomes } ."
396
+ )
397
+ return False
394
398
395
399
# Condition (2)
396
400
if not nx .d_separated (proper_backdoor_graph .graph , set (treatments ), set (outcomes ), set (covariates )):
0 commit comments