@@ -160,9 +160,9 @@ def is_acyclic(self) -> bool:
160160 def get_proper_backdoor_graph (self , treatments : list [str ], outcomes : list [str ]) -> CausalDAG :
161161 """Convert the causal DAG to a proper back-door graph.
162162
163- A proper back-door graph of a causal DAG is obtained by
164- removing the first edge of every proper causal path from treatments to outcomes. A proper causal path from
165- X to Y is a path of directed edges that starts from X and ends in Y.
163+ A proper back-door graph of a causal DAG is obtained by removing the first edge of every proper causal path from
164+ treatments to outcomes. A proper causal path from X to Y is a path of directed edges that starts from X and ends
165+ in Y.
166166
167167 Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
168168 Zander et al., 2019, Definition 3, p.15)
@@ -311,7 +311,12 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
311311 outcome_node_set ,
312312 )
313313 )
314- return minimum_adjustment_sets
314+ valid_minimum_adjustment_sets = [
315+ adj
316+ for adj in minimum_adjustment_sets
317+ if self .constructive_backdoor_criterion (proper_backdoor_graph , treatments , outcomes , adj )
318+ ]
319+ return valid_minimum_adjustment_sets
315320
316321 def adjustment_set_is_minimal (self , treatments : list [str ], outcomes : list [str ], adjustment_set : set [str ]) -> bool :
317322 """Given a list of treatments X, a list of outcomes Y, and an adjustment set Z, determine whether Z is the
@@ -375,22 +380,23 @@ def constructive_backdoor_criterion(
375380 """
376381 # Condition (1)
377382 proper_causal_path_vars = self .proper_causal_pathway (treatments , outcomes )
378- descendents_of_proper_casual_paths = set .union (
379- * [
380- set .union (
381- nx .descendants (self .graph , proper_causal_path_var ),
382- {proper_causal_path_var },
383- )
384- for proper_causal_path_var in proper_causal_path_vars
385- ]
386- )
387-
388- if not set (covariates ).issubset (set (self .graph .nodes ).difference (descendents_of_proper_casual_paths )):
389- logger .info (
390- f"Failed Condition 1: Z={ covariates } **is** a descendent of some variable on a proper causal "
391- f"path between X={ treatments } and Y={ outcomes } ."
383+ if proper_causal_path_vars :
384+ descendents_of_proper_casual_paths = set .union (
385+ * [
386+ set .union (
387+ nx .descendants (self .graph , proper_causal_path_var ),
388+ {proper_causal_path_var },
389+ )
390+ for proper_causal_path_var in proper_causal_path_vars
391+ ]
392392 )
393- return False
393+
394+ if not set (covariates ).issubset (set (self .graph .nodes ).difference (descendents_of_proper_casual_paths )):
395+ logger .info (
396+ f"Failed Condition 1: Z={ covariates } **is** a descendent of some variable on a proper causal "
397+ f"path between X={ treatments } and Y={ outcomes } ."
398+ )
399+ return False
394400
395401 # Condition (2)
396402 if not nx .d_separated (proper_backdoor_graph .graph , set (treatments ), set (outcomes ), set (covariates )):
0 commit comments