Skip to content

Commit 1afaedd

Browse files
Merge branch 'main' into test-coverage-json
2 parents f50af43 + 030b655 commit 1afaedd

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def is_acyclic(self) -> bool:
160160
def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
161161
"""Convert the causal DAG to a proper back-door graph.
162162
163-
A proper back-door graph of a causal DAG is obtained by
164-
removing the first edge of every proper causal path from treatments to outcomes. A proper causal path from
165-
X to Y is a path of directed edges that starts from X and ends in Y.
163+
A proper back-door graph of a causal DAG is obtained by removing the first edge of every proper causal path from
164+
treatments to outcomes. A proper causal path from X to Y is a path of directed edges that starts from X and ends
165+
in Y.
166166
167167
Reference: (Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework,
168168
Zander et al., 2019, Definition 3, p.15)
@@ -311,7 +311,12 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
311311
outcome_node_set,
312312
)
313313
)
314-
return minimum_adjustment_sets
314+
valid_minimum_adjustment_sets = [
315+
adj
316+
for adj in minimum_adjustment_sets
317+
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
318+
]
319+
return valid_minimum_adjustment_sets
315320

316321
def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:
317322
"""Given a list of treatments X, a list of outcomes Y, and an adjustment set Z, determine whether Z is the
@@ -375,22 +380,23 @@ def constructive_backdoor_criterion(
375380
"""
376381
# Condition (1)
377382
proper_causal_path_vars = self.proper_causal_pathway(treatments, outcomes)
378-
descendents_of_proper_casual_paths = set.union(
379-
*[
380-
set.union(
381-
nx.descendants(self.graph, proper_causal_path_var),
382-
{proper_causal_path_var},
383-
)
384-
for proper_causal_path_var in proper_causal_path_vars
385-
]
386-
)
387-
388-
if not set(covariates).issubset(set(self.graph.nodes).difference(descendents_of_proper_casual_paths)):
389-
logger.info(
390-
f"Failed Condition 1: Z={covariates} **is** a descendent of some variable on a proper causal "
391-
f"path between X={treatments} and Y={outcomes}."
383+
if proper_causal_path_vars:
384+
descendents_of_proper_casual_paths = set.union(
385+
*[
386+
set.union(
387+
nx.descendants(self.graph, proper_causal_path_var),
388+
{proper_causal_path_var},
389+
)
390+
for proper_causal_path_var in proper_causal_path_vars
391+
]
392392
)
393-
return False
393+
394+
if not set(covariates).issubset(set(self.graph.nodes).difference(descendents_of_proper_casual_paths)):
395+
logger.info(
396+
f"Failed Condition 1: Z={covariates} **is** a descendent of some variable on a proper causal "
397+
f"path between X={treatments} and Y={outcomes}."
398+
)
399+
return False
394400

395401
# Condition (2)
396402
if not nx.d_separated(proper_backdoor_graph.graph, set(treatments), set(outcomes), set(covariates)):

tests/specification_tests/test_causal_dag.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,30 @@
55
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
66

77

8+
class TestCausalDAGIssue90(unittest.TestCase):
9+
10+
"""
11+
Test the CausalDAG class for the resolution of Issue 90.
12+
"""
13+
14+
def setUp(self) -> None:
15+
temp_dir_path = create_temp_dir_if_non_existent()
16+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
17+
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
18+
with open(self.dag_dot_path, "w") as f:
19+
f.write(dag_dot)
20+
21+
def test_enumerate_minimal_adjustment_sets(self):
22+
"""Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
23+
causal_dag = CausalDAG(self.dag_dot_path)
24+
xs, ys = ["X"], ["Y"]
25+
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
26+
self.assertEqual([{"Z"}], adjustment_sets)
27+
28+
def tearDown(self) -> None:
29+
remove_temp_dir_if_existent()
30+
31+
832
class TestCausalDAG(unittest.TestCase):
933

1034
"""

0 commit comments

Comments
 (0)