Skip to content

Commit 8e8fd8b

Browse files
committed
Testing IV assumptions
1 parent db004c5 commit 8e8fd8b

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ def __init__(self, dot_path: str = None, **attr):
138138
if not self.is_acyclic():
139139
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
140140

141+
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
142+
"""
143+
Checks the three instrumental variable assumptions, raising a
144+
ValueError if any are violated.
145+
146+
:return Boolean True if the three IV assumptions hold.
147+
"""
148+
# (i) Instrument is associated with treatment
149+
if nx.d_separated(self.graph, {instrument}, {treatment}, set()):
150+
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
151+
152+
# (ii) Instrument does not affect outcome except through its potential effect on treatment
153+
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
154+
raise ValueError(
155+
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
156+
)
157+
158+
# (iii) Instrument and outcome do not share causes
159+
if any(
160+
[
161+
cause
162+
for cause in self.graph.nodes
163+
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
164+
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
165+
]
166+
):
167+
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
168+
169+
return True
170+
141171
def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
142172
"""Add an edge to the causal DAG.
143173

tests/specification_tests/test_causal_dag.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,38 @@ def tearDown(self) -> None:
2929
remove_temp_dir_if_existent()
3030

3131

32+
class TestIVAssumptions(unittest.TestCase):
33+
def setUp(self) -> None:
34+
temp_dir_path = create_temp_dir_if_non_existent()
35+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
36+
dag_dot = """digraph G { I -> X; X -> Y; U -> X; U -> Y;}"""
37+
f = open(self.dag_dot_path, "w")
38+
f.write(dag_dot)
39+
f.close()
40+
41+
def test_valid_iv(self):
42+
causal_dag = CausalDAG(self.dag_dot_path)
43+
self.assertTrue(causal_dag.check_iv_assumptions("X", "Y", "I"))
44+
45+
def test_unrelated_instrument(self):
46+
causal_dag = CausalDAG(self.dag_dot_path)
47+
causal_dag.graph.remove_edge("I", "X")
48+
with self.assertRaises(ValueError):
49+
causal_dag.check_iv_assumptions("X", "Y", "I")
50+
51+
def test_direct_cause(self):
52+
causal_dag = CausalDAG(self.dag_dot_path)
53+
causal_dag.graph.add_edge("I", "Y")
54+
with self.assertRaises(ValueError):
55+
causal_dag.check_iv_assumptions("X", "Y", "I")
56+
57+
def test_common_cause(self):
58+
causal_dag = CausalDAG(self.dag_dot_path)
59+
causal_dag.graph.add_edge("U", "I")
60+
with self.assertRaises(ValueError):
61+
causal_dag.check_iv_assumptions("X", "Y", "I")
62+
63+
3264
class TestCausalDAG(unittest.TestCase):
3365

3466
"""

0 commit comments

Comments
 (0)