Skip to content

Commit ca83012

Browse files
authored
Merge pull request #72 from CITCOM-project/total-effect-id
Direct effect
2 parents d0742eb + 08df6ae commit ca83012

File tree

7 files changed

+177
-7
lines changed

7 files changed

+177
-7
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .scenario import Scenario
1212
from .variable import Output
1313

14-
1514
def list_all_min_sep(graph: nx.Graph, treatment_node: Node, outcome_node: Node,
1615
treatment_node_set: set[Node], outcome_node_set: set[Node]):
1716
""" A backtracking algorithm for listing all minimal treatment-outcome separators in an undirected graph.
@@ -219,6 +218,56 @@ def get_ancestor_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
219218
ancestor_graph.graph.remove_nodes_from(variables_to_remove)
220219
return ancestor_graph
221220

221+
def get_indirect_graph(self, treatments:list[str], outcomes:list[str]) -> CausalDAG:
222+
"""
223+
This is the counterpart of the back-door graph for direct effects. We remove only edges pointing from X to Y.
224+
It is a Python implementation of the indirectGraph function from Dagitty.
225+
226+
:param list[str] treatments: List of treatment names.
227+
:param list[str] outcomes: List of outcome names.
228+
:return: The indirect graph with edges pointing from X to Y removed.
229+
:rtype: CausalDAG
230+
"""
231+
gback = self.copy()
232+
ee = []
233+
for s in treatments:
234+
for t in outcomes:
235+
if (s, t) in gback.graph.edges:
236+
ee.append((s, t))
237+
for v1, v2 in ee:
238+
gback.graph.remove_edge(v1,v2)
239+
return gback
240+
241+
def direct_effect_adjustment_sets(self, treatments:list[str], outcomes:list[str]) -> list[set[str]]:
242+
"""
243+
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
244+
and outcomes for DIRECT causal effect.
245+
246+
This is an Python implementation of the listMsasTotalEffect function from Dagitty using Algorithms presented in
247+
Adjustment Criteria in Causal Diagrams: An Algorithmic Perspective, Textor and Lískiewicz, 2012 and extended in
248+
Separators and adjustment sets in causal graphs: Complete criteria and an algorithmic framework, Zander et al.,
249+
2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal,
250+
backtracking algorithms to list the minimal vertex separators of a graph, 2013.
251+
252+
:param list[str] treatments: List of treatment names.
253+
:param list[str] outcomes: List of outcome names.
254+
:return: A list of possible adjustment sets.
255+
:rtype: list[set[str]]
256+
"""
257+
258+
indirect_graph = self.get_indirect_graph(treatments, outcomes)
259+
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
260+
gam = nx.moral_graph(ancestor_graph.graph)
261+
262+
edges_to_add = [("TREATMENT", treatment) for treatment in treatments]
263+
edges_to_add += [("OUTCOME", outcome) for outcome in outcomes]
264+
gam.add_edges_from(edges_to_add)
265+
266+
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
267+
# min_seps.remove(set(outcomes))
268+
return min_seps
269+
270+
222271
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
223272
""" Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
224273
and outcomes.

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class CausalTestCase:
1818

1919
def __init__(self, control_input_configuration: dict[Variable: Any], expected_causal_effect: CausalTestOutcome,
2020
outcome_variables: dict[Variable], treatment_input_configuration: dict[Variable: Any] = None,
21-
estimate_type: str = "ate", effect_modifier_configuration: dict[Variable: Any] = None):
21+
estimate_type: str = "ate", effect_modifier_configuration: dict[Variable: Any] = None, effect: str = "total"):
2222
"""
2323
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
2424
create two distinct input configurations: a control input configuration and a treatment input configuration.
@@ -35,6 +35,7 @@ def __init__(self, control_input_configuration: dict[Variable: Any], expected_ca
3535
self.outcome_variables = outcome_variables
3636
self.treatment_input_configuration = treatment_input_configuration
3737
self.estimate_type = estimate_type
38+
self.effect = effect
3839
if effect_modifier_configuration:
3940
self.effect_modifier_configuration = effect_modifier_configuration
4041
else:

causal_testing/testing/causal_test_engine.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,20 @@ def load_data(self, **kwargs):
5959

6060
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
6161

62-
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(
63-
[v.name for v in self.treatment_variables],
64-
[v.name for v in self.causal_test_case.outcome_variables]
65-
)
62+
minimal_adjustment_sets = []
63+
if self.causal_test_case.effect == "total":
64+
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(
65+
[v.name for v in self.treatment_variables],
66+
[v.name for v in self.causal_test_case.outcome_variables]
67+
)
68+
elif self.causal_test_case.effect == "direct":
69+
minimal_adjustment_sets = self.casual_dag.direct_effect_adjustment_sets(
70+
[v.name for v in self.treatment_variables],
71+
[v.name for v in self.causal_test_case.outcome_variables]
72+
)
73+
else:
74+
raise ValueError("Causal effect should be 'total' or 'direct'")
75+
6676
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
6777
return minimal_adjustment_set
6878

causal_testing/testing/causal_test_outcome.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ def apply(self, res: CausalTestResult) -> bool:
118118
# TODO: confidence intervals?
119119
return res.ate < 0
120120

121+
class SomeEffect(CausalTestOutcome):
122+
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
123+
124+
def apply(self, res: CausalTestResult) -> bool:
125+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
126+
127+
def __str__(self):
128+
return "Changed"
121129

122130
class NoEffect(CausalTestOutcome):
123131
"""An extension of TestOutcome representing that the expected causal effect should be zero."""

tests/specification_tests/test_causal_dag.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,32 @@ def tearDown(self) -> None:
6464
remove_temp_dir_if_existent()
6565

6666

67+
class TestDAGDirectEffectIdentification(unittest.TestCase):
68+
"""
69+
Test the Causal DAG identification algorithms and supporting algorithms.
70+
"""
71+
72+
def setUp(self) -> None:
73+
temp_dir_path = create_temp_dir_if_non_existent()
74+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
75+
dag_dot = (
76+
"""digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
77+
)
78+
f = open(self.dag_dot_path, "w")
79+
f.write(dag_dot)
80+
f.close()
81+
82+
def test_direct_effect_adjustment_sets(self):
83+
causal_dag = CausalDAG(self.dag_dot_path)
84+
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X1"], ["Y"])
85+
self.assertEqual(list(adjustment_sets), [{"Y"}, {"D1", "Z"}, {"X2", "Z"}])
86+
87+
def test_direct_effect_adjustment_sets_no_adjustment(self):
88+
causal_dag = CausalDAG(self.dag_dot_path)
89+
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X2"], ["D1"])
90+
self.assertEqual(list(adjustment_sets), [set()])
91+
92+
6793
class TestDAGIdentification(unittest.TestCase):
6894
"""
6995
Test the Causal DAG identification algorithms and supporting algorithms.
@@ -79,6 +105,15 @@ def setUp(self) -> None:
79105
f.write(dag_dot)
80106
f.close()
81107

108+
def test_get_indirect_graph(self):
109+
causal_dag = CausalDAG(self.dag_dot_path)
110+
indirect_graph = causal_dag.get_indirect_graph(["D1"], ["Y"])
111+
original_edges = list(causal_dag.graph.edges)
112+
original_edges.remove(("D1", "Y"))
113+
self.assertEqual(list(indirect_graph.graph.edges), original_edges)
114+
self.assertEqual(indirect_graph.graph.nodes, causal_dag.graph.nodes)
115+
116+
82117
def test_proper_backdoor_graph(self):
83118
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
84119
causal_dag = CausalDAG(self.dag_dot_path)

tests/testing_tests/test_causal_test_engine.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,26 @@ def test_execute_test_observational_causal_forest_estimator(self):
141141
causal_test_result = self.causal_test_engine.execute_test(estimation_model)
142142
self.assertAlmostEqual(causal_test_result.ate, 4, delta=1)
143143

144+
145+
def test_invalid_causal_effect(self):
146+
""" Check that executing the causal test case returns the correct results for dummy data using a linear
147+
regression estimator. """
148+
causal_test_case = CausalTestCase(
149+
control_input_configuration={self.A: 0},
150+
expected_causal_effect=self.expected_causal_effect,
151+
treatment_input_configuration={self.A: 1},
152+
outcome_variables={self.C},
153+
effect="error")
154+
# 5. Create causal test engine
155+
causal_test_engine = CausalTestEngine(
156+
causal_test_case,
157+
self.causal_specification,
158+
self.data_collector
159+
)
160+
with self.assertRaises(Exception):
161+
causal_test_engine.load_data()
162+
163+
144164
def test_execute_test_observational_linear_regression_estimator(self):
145165
""" Check that executing the causal test case returns the correct results for dummy data using a linear
146166
regression estimator. """
@@ -153,6 +173,37 @@ def test_execute_test_observational_linear_regression_estimator(self):
153173
causal_test_result = self.causal_test_engine.execute_test(estimation_model)
154174
self.assertAlmostEqual(causal_test_result.ate, 4, delta=1e-10)
155175

176+
177+
def test_execute_test_observational_linear_regression_estimator_direct_effect(self):
178+
""" Check that executing the causal test case returns the correct results for dummy data using a linear
179+
regression estimator. """
180+
causal_test_case = CausalTestCase(
181+
control_input_configuration={self.A: 0},
182+
expected_causal_effect=self.expected_causal_effect,
183+
treatment_input_configuration={self.A: 1},
184+
outcome_variables={self.C},
185+
effect="direct")
186+
187+
# 5. Create causal test engine
188+
causal_test_engine = CausalTestEngine(
189+
causal_test_case,
190+
self.causal_specification,
191+
self.data_collector
192+
)
193+
self.minimal_adjustment_set = causal_test_engine.load_data()
194+
195+
# 6. Easier to access treatment and outcome values
196+
self.treatment_value = 1
197+
self.control_value = 0
198+
estimation_model = LinearRegressionEstimator(('A',),
199+
self.treatment_value,
200+
self.control_value,
201+
self.minimal_adjustment_set,
202+
('C',),
203+
causal_test_engine.scenario_execution_data_df)
204+
causal_test_result = causal_test_engine.execute_test(estimation_model)
205+
self.assertAlmostEqual(causal_test_result.ate, 0, delta=1e-10)
206+
156207
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
157208
""" Check that executing the causal test case returns the correct results for dummy data using a linear
158209
regression estimator. """

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from causal_testing.testing.causal_test_outcome import CausalTestResult, ExactValue
2+
from causal_testing.testing.causal_test_outcome import CausalTestResult, ExactValue, SomeEffect
33

44
class TestCausalTestOutcome(unittest.TestCase):
55
""" Test the TestCausalTestOutcome basic methods.
@@ -28,3 +28,19 @@ def test_exactValue_fail(self):
2828
confidence_intervals = None, effect_modifier_configuration = None)
2929
ev = ExactValue(5, 0.1)
3030
self.assertFalse(ev.apply(ctr))
31+
32+
33+
def test_someEffect_pass(self):
34+
ctr = CausalTestResult(treatment="A", outcome="A", treatment_value=1,
35+
control_value=0, adjustment_set={}, ate=5.05,
36+
confidence_intervals = [4.8, 6.7], effect_modifier_configuration = None)
37+
ev = SomeEffect()
38+
self.assertTrue(ev.apply(ctr))
39+
40+
41+
def test_someEffect_fail(self):
42+
ctr = CausalTestResult(treatment="A", outcome="A", treatment_value=1,
43+
control_value=0, adjustment_set={}, ate=0,
44+
confidence_intervals = [-0.1, 0.2], effect_modifier_configuration = None)
45+
ev = SomeEffect()
46+
self.assertFalse(ev.apply(ctr))

0 commit comments

Comments
 (0)