Skip to content

Commit 9883915

Browse files
jonititanOriolAbrilJoni Pelham
authored
Added networkx export functionality (#6046)
* Added networkx export functionality Apologies for delay but finally got around to updating the latest version of model_graph.py to include networkx export function as discussed in #5677 * Update model_graph.py restoring whitespace #5677 (comment) * Corrected cluster(subgraph) behaviour subgraph(cluster) attributes are now stored on each node that is a member of that subgraph(cluster). * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Correcting docstring spacing * Updated to include new function model_to_networkx * linted and checked by pre-commit * added test for model_to_networkx function * added function import * added model_to_networkx * corrected formatting * more linting * Update test_model_graph.py * Update __init__.py * Redid changes * added function to models.rst * corrected indent issue indent issue cause by merge resolution * corrected spelling error * added networkx to requirements * redid pre-commit and ran it twice * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Update pymc/model_graph.py Co-authored-by: Oriol Abril-Pla <[email protected]> * Removed trailing whitespace Removed trailing whitespace from lines 312, 314, * removing more trailing whitespace Co-authored-by: Oriol Abril-Pla <[email protected]> Co-authored-by: Joni Pelham <[email protected]>
1 parent 906fcdc commit 9883915

File tree

9 files changed

+200
-8
lines changed

9 files changed

+200
-8
lines changed

conda-envs/environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- pandas>=0.24.0
1818
- pip
1919
- python-graphviz
20+
- networkx
2021
- scipy>=1.4.1
2122
- typing-extensions>=3.7.4
2223
# Extra dependencies for dev, testing and docs build

conda-envs/environment-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- numpy>=1.15.0
2020
- pandas>=0.24.0
2121
- python-graphviz
22+
- networkx
2223
- scipy>=1.4.1
2324
- typing-extensions>=3.7.4
2425
# Extra dependencies for testing

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- pandas>=0.24.0
1818
- pip
1919
- python-graphviz
20+
- networkx
2021
- scipy>=1.4.1
2122
- typing-extensions>=3.7.4
2223
# Extra dependencies for dev, testing and docs build

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies:
2020
- pandas>=0.24.0
2121
- pip
2222
- python-graphviz
23+
- networkx
2324
- scipy>=1.4.1
2425
- typing-extensions>=3.7.4
2526
# Extra dependencies for testing

docs/source/api/model.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Model creation and inspection
1010

1111
Model
1212
model_to_graphviz
13+
model_to_networkx
1314
modelcontext
1415

1516
Others

pymc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __set_compiler_flags():
6565
probit,
6666
)
6767
from pymc.model import *
68-
from pymc.model_graph import model_to_graphviz
68+
from pymc.model_graph import model_to_graphviz, model_to_networkx
6969
from pymc.plots import *
7070
from pymc.printing import *
7171
from pymc.sampling import *

pymc/model_graph.py

Lines changed: 138 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def make_compute_graph(
125125

126126
return input_map
127127

128-
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
129-
"""Attaches the given variable to a graphviz Digraph"""
128+
def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"):
129+
"""Attaches the given variable to a graphviz or networkx Digraph"""
130130
v = self.model[var_name]
131131

132132
shape = None
@@ -168,7 +168,13 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
168168
"label": label,
169169
}
170170

171-
graph.node(var_name.replace(":", "&"), **kwargs)
171+
if cluster:
172+
kwargs["cluster"] = cluster
173+
174+
if nx:
175+
graph.add_node(var_name.replace(":", "&"), **kwargs)
176+
else:
177+
graph.node(var_name.replace(":", "&"), **kwargs)
172178

173179
def _eval(self, var):
174180
return function([], var, mode="FAST_COMPILE")()
@@ -178,7 +184,6 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
178184
179185
Just groups by the shape of the underlying distribution. Will be wrong
180186
if there are two plates with the same shape.
181-
182187
Returns
183188
-------
184189
dict
@@ -234,9 +239,134 @@ def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting:
234239

235240
return graph
236241

242+
def make_networkx(
243+
self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
244+
):
245+
"""Make networkx Digraph of PyMC model
246+
247+
Returns
248+
-------
249+
networkx.Digraph
250+
"""
251+
try:
252+
import networkx
253+
except ImportError:
254+
raise ImportError(
255+
"This function requires the python library networkx, along with binaries. "
256+
"The easiest way to install all of this is by running\n\n"
257+
"\tconda install networkx"
258+
)
259+
graphnetwork = networkx.DiGraph(name=self.model.name)
260+
for plate_label, all_var_names in self.get_plates(var_names).items():
261+
if plate_label:
262+
# # must be preceded by 'cluster' to get a box around it
263+
264+
subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label)
265+
266+
for var_name in all_var_names:
267+
self._make_node(
268+
var_name,
269+
subgraphnetwork,
270+
nx=True,
271+
cluster="cluster" + plate_label,
272+
formatting=formatting,
273+
)
274+
for sgn in subgraphnetwork.nodes:
275+
networkx.set_node_attributes(
276+
subgraphnetwork,
277+
{sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}},
278+
)
279+
node_data = {
280+
e[0]: e[1]
281+
for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True)
282+
}
283+
284+
graphnetwork = networkx.compose(graphnetwork, subgraphnetwork)
285+
networkx.set_node_attributes(graphnetwork, node_data)
286+
graphnetwork.graph["name"] = self.model.name
287+
else:
288+
for var_name in all_var_names:
289+
290+
self._make_node(var_name, graphnetwork, nx=True, formatting=formatting)
291+
292+
for child, parents in self.make_compute_graph(var_names=var_names).items():
293+
# parents is a set of rv names that preceed child rv nodes
294+
for parent in parents:
295+
graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&"))
296+
return graphnetwork
297+
298+
299+
def model_to_networkx(
300+
model=None,
301+
*,
302+
var_names: Optional[Iterable[VarName]] = None,
303+
formatting: str = "plain",
304+
):
305+
"""Produce a networkx Digraph from a PyMC model.
306+
307+
Requires networkx, which may be installed most easily with::
308+
309+
conda install networkx
310+
311+
Alternatively, you may install using pip with::
312+
313+
pip install networkx
314+
315+
See https://networkx.org/documentation/stable/ for more information.
316+
317+
Parameters
318+
----------
319+
model : Model
320+
The model to plot. Not required when called from inside a modelcontext.
321+
var_names : iterable of str, optional
322+
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
323+
formatting : str, optional
324+
one of { "plain" }
325+
326+
Examples
327+
--------
328+
How to plot the graph of the model.
329+
330+
.. code-block:: python
331+
332+
import numpy as np
333+
from pymc import HalfCauchy, Model, Normal, model_to_networkx
334+
335+
J = 8
336+
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
337+
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
338+
339+
with Model() as schools:
340+
341+
eta = Normal("eta", 0, 1, shape=J)
342+
mu = Normal("mu", 0, sigma=1e6)
343+
tau = HalfCauchy("tau", 25)
344+
345+
theta = mu + tau * eta
346+
347+
obs = Normal("obs", theta, sigma=sigma, observed=y)
348+
349+
model_to_networkx(schools)
350+
"""
351+
if not "plain" in formatting:
352+
353+
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
354+
355+
if formatting != "plain":
356+
warnings.warn(
357+
"Formattings other than 'plain' are currently not supported.",
358+
UserWarning,
359+
stacklevel=2,
360+
)
361+
model = pm.modelcontext(model)
362+
return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting)
363+
237364

238365
def model_to_graphviz(
239-
model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
366+
model=None,
367+
*,
368+
var_names: Optional[Iterable[VarName]] = None,
369+
formatting: str = "plain",
240370
):
241371
"""Produce a graphviz Digraph from a PyMC model.
242372
@@ -286,7 +416,9 @@ def model_to_graphviz(
286416
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
287417
if formatting != "plain":
288418
warnings.warn(
289-
"Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
419+
"Formattings other than 'plain' are currently not supported.",
420+
UserWarning,
421+
stacklevel=2,
290422
)
291423
model = pm.modelcontext(model)
292424
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)

pymc/tests/test_model_graph.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,64 @@
2020

2121
import pymc as pm
2222

23-
from pymc.model_graph import ModelGraph, model_to_graphviz
23+
from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
2424
from pymc.tests.helpers import SeededTest
2525

2626

27+
def school_model():
28+
"""
29+
Schools model to use in testing model_to_networkx function
30+
"""
31+
J = 8
32+
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
33+
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
34+
with pm.Model() as schools:
35+
eta = pm.Normal("eta", 0, 1, shape=J)
36+
mu = pm.Normal("mu", 0, sigma=1e6)
37+
tau = pm.HalfCauchy("tau", 25)
38+
theta = mu + tau * eta
39+
obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
40+
return schools
41+
42+
43+
class BaseModelNXTest(SeededTest):
44+
network_model = {
45+
"graph_attr_dict_factory": dict,
46+
"node_dict_factory": dict,
47+
"node_attr_dict_factory": dict,
48+
"adjlist_outer_dict_factory": dict,
49+
"adjlist_inner_dict_factory": dict,
50+
"edge_attr_dict_factory": dict,
51+
"graph": {"name": "", "label": "8"},
52+
"_node": {
53+
"eta": {
54+
"shape": "ellipse",
55+
"style": "rounded",
56+
"label": "eta\n~\nNormal",
57+
"cluster": "cluster8",
58+
"labeljust": "r",
59+
"labelloc": "b",
60+
},
61+
"obs": {
62+
"shape": "ellipse",
63+
"style": "rounded",
64+
"label": "obs\n~\nNormal",
65+
"cluster": "cluster8",
66+
"labeljust": "r",
67+
"labelloc": "b",
68+
},
69+
"tau": {"shape": "ellipse", "style": None, "label": "tau\n~\nHalfCauchy"},
70+
"mu": {"shape": "ellipse", "style": None, "label": "mu\n~\nNormal"},
71+
},
72+
"_adj": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}},
73+
"_pred": {"eta": {}, "obs": {"tau": {}, "eta": {}, "mu": {}}, "tau": {}, "mu": {}},
74+
"_succ": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}},
75+
}
76+
77+
def test_networkx(self):
78+
assert self.network_model == model_to_networkx(school_model()).__dict__
79+
80+
2781
def radon_model():
2882
"""Similar in shape to the Radon model"""
2983
n_homes = 919

scripts/generate_pip_deps_from_conda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"mkl-service",
5252
"numba",
5353
"python-graphviz",
54+
"networkx",
5455
"blas",
5556
"jax",
5657
}

0 commit comments

Comments
 (0)