Skip to content

Commit 8bff2a9

Browse files
author
AndrewC19
committed
Fixed direct effect identification and test case
1 parent 5c57f49 commit 8bff2a9

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
255255
gam.add_edges_from(edges_to_add)
256256

257257
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
258-
# min_seps.remove(set(outcomes))
258+
if set(outcomes) in min_seps:
259+
min_seps.remove(set(outcomes))
259260
return min_seps
260261

261262
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:

tests/specification_tests/test_causal_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def setUp(self) -> None:
106106
def test_direct_effect_adjustment_sets(self):
107107
causal_dag = CausalDAG(self.dag_dot_path)
108108
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X1"], ["Y"])
109-
self.assertEqual(list(adjustment_sets), [{"Y"}, {"D1", "Z"}, {"X2", "Z"}])
109+
self.assertEqual(list(adjustment_sets), [{"D1", "Z"}, {"X2", "Z"}])
110110

111111
def test_direct_effect_adjustment_sets_no_adjustment(self):
112112
causal_dag = CausalDAG(self.dag_dot_path)

0 commit comments

Comments
 (0)