Skip to content

Commit 1785ec1

Browse files
author
Chris Fonnesbeck
committed
Merge pull request #745 from pymc-devs/double_plot_fix
Fix for double-plotting of some output
2 parents 2043dba + c16b785 commit 1785ec1

File tree

1 file changed

+55
-27
lines changed

1 file changed

+55
-27
lines changed

pymc3/plots.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def traceplot(trace, vars=None, figsize=None,
10-
lines=None, combined=False, grid=True):
10+
lines=None, combined=False, grid=True, ax=None):
1111
"""Plot samples histograms and values
1212
1313
Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
2727
(default), chains will be plotted separately.
2828
grid : bool
2929
Flag for adding gridlines to histogram. Defaults to True.
30+
ax : axes
31+
Matplotlib axes. Defaults to None.
3032
3133
Returns
3234
-------
3335
34-
fig : figure object
36+
ax : matplotlib axes
3537
3638
"""
3739
import matplotlib.pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
4345
if figsize is None:
4446
figsize = (12, n*2)
4547

46-
fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
48+
if ax is None:
49+
fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
50+
elif ax.shape != (n,2):
51+
print('traceplot requires n*2 subplots')
52+
return None
4753

4854
for i, v in enumerate(vars):
4955
for d in trace.get_values(v, combine=combined, squeeze=False):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
6975
pass
7076

7177
plt.tight_layout()
72-
return fig
78+
return ax
7379

7480
def histplot_op(ax, data):
7581
for i in range(data.shape[1]):
@@ -128,23 +134,45 @@ def kde2plot_op(ax, x, y, grid=200):
128134
extent=[xmin, xmax, ymin, ymax])
129135

130136

131-
def kdeplot(data):
132-
f, ax = subplots(1, 1, squeeze=True)
137+
def kdeplot(data, ax=None):
138+
if ax is None:
139+
f, ax = subplots(1, 1, squeeze=True)
133140
kdeplot_op(ax, data)
134-
return f
141+
return ax
135142

136143

137-
def kde2plot(x, y, grid=200):
138-
f, ax = subplots(1, 1, squeeze=True)
144+
def kde2plot(x, y, grid=200, ax=None):
145+
if ax is None:
146+
f, ax = subplots(1, 1, squeeze=True)
139147
kde2plot_op(ax, x, y, grid)
140-
return f
148+
return ax
141149

142150

143-
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
144-
"""Bar plot of the autocorrelation function for a trace"""
151+
def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None):
152+
"""Bar plot of the autocorrelation function for a trace
153+
154+
Parameters
155+
----------
156+
157+
trace : result of MCMC run
158+
vars : list of variable names
159+
Variables to be plotted, if None all variable are plotted
160+
max_lag : int
161+
Maximum lag to calculate autocorrelation. Defaults to 100.
162+
burn : int
163+
Number of samples to discard from the beginning of the trace.
164+
Defaults to 0.
165+
ax : axes
166+
Matplotlib axes. Defaults to None.
167+
168+
Returns
169+
-------
170+
171+
ax : matplotlib axes
172+
173+
"""
174+
145175
import matplotlib.pyplot as plt
146-
if fontmap is None:
147-
fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
148176

149177
if vars is None:
150178
vars = trace.varnames
@@ -153,13 +181,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153181

154182
chains = trace.nchains
155183

156-
f, ax = plt.subplots(len(vars), chains, squeeze=False)
184+
fig, ax = plt.subplots(len(vars), chains, squeeze=False)
157185

158186
max_lag = min(len(trace) - 1, max_lag)
159187

160188
for i, v in enumerate(vars):
161189
for j in range(chains):
162-
d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin))
190+
d = np.squeeze(trace.get_values(v, chains=[j], burn=burn))
163191

164192
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
165193

@@ -169,13 +197,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169197

170198
if chains > 1:
171199
ax[i, j].set_title("chain {0}".format(j+1))
172-
173-
# Smaller tick labels
174-
tlabels = plt.gca().get_xticklabels()
175-
plt.setp(tlabels, 'fontsize', fontmap[1])
176-
177-
tlabels = plt.gca().get_yticklabels()
178-
plt.setp(tlabels, 'fontsize', fontmap[1])
200+
201+
return (fig, ax)
179202

180203

181204
def var_str(name, shape):
@@ -200,7 +223,7 @@ def var_str(name, shape):
200223

201224
def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
202225
main=None, xtitle=None, xrange=None, ylabels=None,
203-
chain_spacing=0.05, vline=0):
226+
chain_spacing=0.05, vline=0, gs=None):
204227
""" Forest plot (model summary plot)
205228
206229
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -245,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
245268
246269
vline (optional): numeric
247270
Location of vertical reference line (defaults to 0).
271+
272+
gs : GridSpec
273+
Matplotlib GridSpec object. Defaults to None.
274+
275+
Returns
276+
-------
277+
278+
gs : matplotlib GridSpec
248279
249280
"""
250281
import matplotlib.pyplot as plt
@@ -270,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
270301
# Number of chains
271302
chains = None
272303

273-
# Gridspec
274-
gs = None
275-
276304
# Subplots
277305
interval_plot = None
278306
rhat_plot = None

0 commit comments

Comments
 (0)