Skip to content

Commit 6b6a83d

Browse files
Merge branch 'main' into json_concrete_param
2 parents 7c185b4 + 35146f1 commit 6b6a83d

File tree

4 files changed

+19
-30
lines changed

4 files changed

+19
-30
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Union
99

1010
import networkx as nx
11+
import pydot
1112

1213
from causal_testing.testing.base_test_case import BaseTestCase
1314

@@ -133,7 +134,8 @@ class CausalDAG(nx.DiGraph):
133134
def __init__(self, dot_path: str = None, **attr):
134135
super().__init__(**attr)
135136
if dot_path:
136-
self.graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(dot_path))
137+
pydot_graph = pydot.graph_from_dot_file(dot_path)
138+
self.graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(pydot_graph[0]))
137139
else:
138140
self.graph = nx.DiGraph()
139141

docs/source/installation.rst

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,6 @@ Requirements
66
CausalTestingFramework requires python version 3.9 or later
77
If installing on Windows, ensure `Microsoft Visual C++ <https://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist>`_ is version 14.0 or greater
88

9-
Pygraphviz
10-
----------
11-
12-
Pygraphviz can be installed through the conda-forge channel::
13-
14-
conda install -c conda-forge pygraphviz
15-
16-
17-
Alternatively, on Linux systems, this can be done with `sudo apt install graphviz libgraphviz-dev`.
18-
199
Pip Install
2010
-----------
2111
To install the Causal Testing Framework using :code:`pip` for the latest stable version::

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ dependencies = [
2525
"scikit_learn~=1.1",
2626
"scipy~=1.7",
2727
"statsmodels~=0.13",
28-
"tabulate~=0.8"
28+
"tabulate~=0.8",
29+
"pydot~=1.4"
2930
]
3031
dynamic = ["version"]
3132

tests/specification_tests/test_causal_dag.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class TestCausalDAGIssue90(unittest.TestCase):
9-
109
"""
1110
Test the CausalDAG class for the resolution of Issue 90.
1211
"""
@@ -62,7 +61,6 @@ def test_common_cause(self):
6261

6362

6463
class TestCausalDAG(unittest.TestCase):
65-
6664
"""
6765
Test the CausalDAG class for creation of Causal Directed Acyclic Graphs (DAGs).
6866
@@ -176,19 +174,18 @@ def test_proper_backdoor_graph(self):
176174
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
177175
causal_dag = CausalDAG(self.dag_dot_path)
178176
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(["X1", "X2"], ["Y"])
179-
self.assertEqual(
180-
list(proper_backdoor_graph.graph.edges),
181-
[
182-
("X1", "X2"),
183-
("X2", "V"),
184-
("X2", "D2"),
185-
("D1", "D2"),
186-
("D1", "Y"),
187-
("Y", "D3"),
188-
("Z", "X2"),
189-
("Z", "Y"),
190-
],
191-
)
177+
edges = set([
178+
("X1", "X2"),
179+
("X2", "V"),
180+
("X2", "D2"),
181+
("D1", "D2"),
182+
("D1", "Y"),
183+
("Y", "D3"),
184+
("Z", "X2"),
185+
("Z", "Y"),
186+
])
187+
self.assertTrue(
188+
set(proper_backdoor_graph.graph.edges).issubset(edges))
192189

193190
def test_constructive_backdoor_criterion_should_hold(self):
194191
"""Test whether the constructive criterion holds when it should."""
@@ -198,7 +195,7 @@ def test_constructive_backdoor_criterion_should_hold(self):
198195
self.assertTrue(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
199196

200197
def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_proper_backdoor_graph(
201-
self,
198+
self,
202199
):
203200
"""Test whether the constructive criterion fails when the adjustment set is not a d-separator."""
204201
causal_dag = CausalDAG(self.dag_dot_path)
@@ -207,7 +204,7 @@ def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_prop
207204
self.assertFalse(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
208205

209206
def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_causal_path(
210-
self,
207+
self,
211208
):
212209
"""Test whether the constructive criterion holds when the adjustment set Z contains a descendent of a variable
213210
on a proper causal path between X and Y."""
@@ -392,7 +389,6 @@ def tearDown(self) -> None:
392389

393390

394391
class TestUndirectedGraphAlgorithms(unittest.TestCase):
395-
396392
"""
397393
Test the graph algorithms designed for the undirected graph variants of a Causal DAG.
398394
During the identification process, a Causal DAG is converted into several forms of undirected graph which allow for

0 commit comments

Comments
 (0)