Skip to content

Commit 59f0657

Browse files
committed
Separate test for to dot
1 parent 3841f8c commit 59f0657

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,10 @@ def to_dot(self) -> str:
525525
"""Return a string of the DOT representation of the causal DAG.
526526
:return DOT string of the DAG.
527527
"""
528-
return str(nx.nx_pydot.to_pydot(self.graph))
528+
dotstring = "digraph G {\n"
529+
dotstring += "".join([f"{a} -> {b};\n" for a, b in self.graph.edges])
530+
dotstring += "}"
531+
return dotstring
529532

530533
def __str__(self):
531534
return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"

tests/json_front_tests/test_json_class.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def test_generate_coefficient_tests_from_json(self):
116116
}
117117
]
118118
}
119-
print(self.json_class.causal_specification.causal_dag.to_dot())
120119
self.json_class.test_plan = example_test
121120
effects = {"NoEffect": NoEffect()}
122121
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}

tests/specification_tests/test_causal_dag.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class TestCausalDAG(unittest.TestCase):
7373
def setUp(self) -> None:
7474
temp_dir_path = create_temp_dir_if_non_existent()
7575
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
76-
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C}"""
76+
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C;}"""
7777
f = open(self.dag_dot_path, "w")
7878
f.write(dag_dot)
7979
f.close()
@@ -99,6 +99,10 @@ def test_empty_casual_dag(self):
9999
causal_dag = CausalDAG()
100100
assert list(causal_dag.graph.nodes) == [] and list(causal_dag.graph.edges) == []
101101

102+
def test_to_dot(self):
103+
causal_dag = CausalDAG(self.dag_dot_path)
104+
self.assertEqual(causal_dag.to_dot(), """digraph G {\nA -> B;\nB -> C;\nD -> A;\nD -> C;\n}""")
105+
102106
def tearDown(self) -> None:
103107
remove_temp_dir_if_existent()
104108

0 commit comments

Comments
 (0)