Skip to content

Commit d2fdc91

Browse files
author
AndrewC19
committed
Apply constructive back-door criterion to minimal separators
1 parent 8122eb4 commit d2fdc91

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def is_acyclic(self) -> bool:
160160
def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
161161
"""Convert the causal DAG to a proper back-door graph.
162162
163-
A proper back-door graph of a causal DAG is obtained by
164-
removing the first edge of every proper causal path from treatments to outcomes. A proper causal path from
165-
X to Y is a path of directed edges that starts from X and ends in Y.
163+
A proper back-door graph of a causal DAG is obtained by removing the first edge of every proper causal path from
164+
treatments to outcomes. A proper causal path from X to Y is a path of directed edges that starts from X and ends
165+
in Y.
166166
167167
Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
168168
Zander et al., 2019, Definition 3, p.15)
@@ -311,7 +311,10 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
311311
outcome_node_set,
312312
)
313313
)
314-
return minimum_adjustment_sets
314+
valid_minimum_adjustment_sets = [adj for adj in minimum_adjustment_sets
315+
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments,
316+
outcomes, adj)]
317+
return valid_minimum_adjustment_sets
315318

316319
def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:
317320
"""Given a list of treatments X, a list of outcomes Y, and an adjustment set Z, determine whether Z is the
@@ -375,22 +378,23 @@ def constructive_backdoor_criterion(
375378
"""
376379
# Condition (1)
377380
proper_causal_path_vars = self.proper_causal_pathway(treatments, outcomes)
378-
descendents_of_proper_casual_paths = set.union(
379-
*[
380-
set.union(
381-
nx.descendants(self.graph, proper_causal_path_var),
382-
{proper_causal_path_var},
383-
)
384-
for proper_causal_path_var in proper_causal_path_vars
385-
]
386-
)
387-
388-
if not set(covariates).issubset(set(self.graph.nodes).difference(descendents_of_proper_casual_paths)):
389-
logger.info(
390-
f"Failed Condition 1: Z={covariates} **is** a descendent of some variable on a proper causal "
391-
f"path between X={treatments} and Y={outcomes}."
381+
if proper_causal_path_vars:
382+
descendents_of_proper_casual_paths = set.union(
383+
*[
384+
set.union(
385+
nx.descendants(self.graph, proper_causal_path_var),
386+
{proper_causal_path_var},
387+
)
388+
for proper_causal_path_var in proper_causal_path_vars
389+
]
392390
)
393-
return False
391+
392+
if not set(covariates).issubset(set(self.graph.nodes).difference(descendents_of_proper_casual_paths)):
393+
logger.info(
394+
f"Failed Condition 1: Z={covariates} **is** a descendent of some variable on a proper causal "
395+
f"path between X={treatments} and Y={outcomes}."
396+
)
397+
return False
394398

395399
# Condition (2)
396400
if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)):

tests/specification_tests/test_causal_dag.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep
55
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
66

7+
78
class TestCausalDAGIssue90(unittest.TestCase):
89

910
"""
@@ -13,16 +14,16 @@ class TestCausalDAGIssue90(unittest.TestCase):
1314
def setUp(self) -> None:
1415
temp_dir_path = create_temp_dir_if_non_existent()
1516
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
16-
dag_dot = """digraph DAG { rankdir=LR; content -> weight; weight -> S3; country -> S3; country -> distance; content -> S3; plane_transport -> S1; plane_transport -> S2; S1 -> alarm; S2 -> alarm; S3 -> alarm; }"""
17+
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
1718
with open(self.dag_dot_path, "w") as f:
1819
f.write(dag_dot)
1920

2021
def test_enumerate_minimal_adjustment_sets(self):
2122
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
2223
causal_dag = CausalDAG(self.dag_dot_path)
23-
xs, ys = ["weight"], ["alarm"]
24+
xs, ys = ["X"], ["Y"]
2425
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
25-
self.assertEqual([{"content"}], adjustment_sets)
26+
self.assertEqual([{"Z"}], adjustment_sets)
2627

2728
def tearDown(self) -> None:
2829
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)