Skip to content

Commit 988f522

Browse files
sidravi1ColCarroll
authored andcommitted
Fix for #3483: model_to_graphviz now works with shared variables (#3490)
* fix for 3483 + modified test_model_grah.py to test it * fixed model_graph.py to handle shardvars * adding import for theano
1 parent 6a6da63 commit 988f522

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

pymc3/model_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def make_graph(self):
161161
'\tconda install -c conda-forge python-graphviz')
162162
graph = graphviz.Digraph(self.model.name)
163163
for shape, var_names in self.get_plates().items():
164+
if isinstance(shape, SharedVariable):
165+
shape = shape.eval()
164166
label = ' x '.join(map('{:,d}'.format, shape))
165167
if label:
166168
# must be preceded by 'cluster' to get a box around it

pymc3/tests/test_model_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pymc3 as pm
3+
import theano as th
34
from pymc3.model_graph import ModelGraph, model_to_graphviz
4-
55
from .helpers import SeededTest
66

77

@@ -14,6 +14,8 @@ def radon_model():
1414
floor_measure = np.random.randint(0, 2, size=n_homes)
1515
log_radon = np.random.normal(1, 1, size=n_homes)
1616

17+
floor_measure = th.shared(floor_measure)
18+
1719
d, r = divmod(919, 85)
1820
county = np.hstack((
1921
np.tile(np.arange(counties, dtype=int), d),

0 commit comments

Comments
 (0)