Skip to content

Commit 0a070b4

Browse files
committed
test causal dag patch
1 parent f2df2d1 commit 0a070b4

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
558558
estimate as opposed to a purely associational estimate.
559559
"""
560560
if self.ignore_cycles:
561-
return self.graph.predecessors(base_test_case.treatment_variable.name)
561+
return set(self.graph.predecessors(base_test_case.treatment_variable.name))
562562
minimal_adjustment_sets = []
563563
if base_test_case.effect == "total":
564564
minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets(
@@ -578,7 +578,7 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
578578
return set()
579579

580580
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
581-
return minimal_adjustment_set
581+
return set(minimal_adjustment_set)
582582

583583
def to_dot_string(self) -> str:
584584
"""Return a string of the DOT representation of the causal DAG.

tests/specification_tests/test_causal_dag.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def setUp(self) -> None:
8585
def test_valid_causal_dag(self):
8686
"""Test whether the Causal DAG is valid."""
8787
causal_dag = CausalDAG(self.dag_dot_path)
88-
print(causal_dag)
8988
assert list(causal_dag.nodes) == ["A", "B", "C", "D"] and list(causal_dag.edges) == [
9089
("A", "B"),
9190
("B", "C"),
@@ -127,6 +126,11 @@ def setUp(self) -> None:
127126
def test_invalid_causal_dag(self):
128127
self.assertRaises(nx.HasACycle, CausalDAG, self.dag_dot_path)
129128

129+
def test_ignore_cycles(self):
130+
dag = CausalDAG(self.dag_dot_path, ignore_cycles=True)
131+
base_test_case = BaseTestCase(Output("B", float), Output("C", float))
132+
self.assertEqual(dag.identification(base_test_case), {"A"})
133+
130134
def tearDown(self) -> None:
131135
shutil.rmtree(self.temp_dir_path)
132136

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_should_not_cause_json_stub(self):
5555
"estimate_type": "coefficient",
5656
"estimator": "LinearRegressionEstimator",
5757
"expected_effect": {"Z": "NoEffect"},
58-
"mutations": ["X1"],
58+
"treatment_variable": "X1",
5959
"name": "X1 _||_ Z",
6060
"formula": "Z ~ X1",
6161
"alpha": 0.05,
@@ -78,7 +78,7 @@ def test_should_cause_json_stub(self):
7878
"estimator": "LinearRegressionEstimator",
7979
"expected_effect": {"Z": "SomeEffect"},
8080
"formula": "Z ~ X1",
81-
"mutations": ["X1"],
81+
"treatment_variable": "X1",
8282
"name": "X1 --> Z",
8383
"skip": True,
8484
},

0 commit comments

Comments
 (0)