Skip to content

Commit 2043dba

Browse files
author
Chris Fonnesbeck
committed
Revert "Plots return (fig, ax) tuple to prevent double-plotting; added docstring to autocorrplot"
This reverts commit bef2ee6.
1 parent bef2ee6 commit 2043dba

File tree

1 file changed

+17
-30
lines changed

1 file changed

+17
-30
lines changed

pymc3/plots.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
3131
Returns
3232
-------
3333
34-
fig, ax : tuple of matplotlib figure and axes
34+
fig : figure object
3535
3636
"""
3737
import matplotlib.pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
6969
pass
7070

7171
plt.tight_layout()
72-
return (fig, ax)
72+
return fig
7373

7474
def histplot_op(ax, data):
7575
for i in range(data.shape[1]):
@@ -131,38 +131,20 @@ def kde2plot_op(ax, x, y, grid=200):
131131
def kdeplot(data):
132132
f, ax = subplots(1, 1, squeeze=True)
133133
kdeplot_op(ax, data)
134-
return f, ax
134+
return f
135135

136136

137137
def kde2plot(x, y, grid=200):
138138
f, ax = subplots(1, 1, squeeze=True)
139139
kde2plot_op(ax, x, y, grid)
140-
return f, ax
140+
return f
141141

142142

143-
def autocorrplot(trace, vars=None, max_lag=100, burn=0):
144-
"""Bar plot of the autocorrelation function for a trace
145-
146-
Parameters
147-
----------
148-
149-
trace : result of MCMC run
150-
vars : list of variable names
151-
Variables to be plotted, if None all variable are plotted
152-
max_lag: int
153-
Maximum lag to calculate autocorrelation. Defaults to 100.
154-
burn: int
155-
Number of samples to discard from the beginning of the trace.
156-
Defaults to 0.
157-
158-
Returns
159-
-------
160-
161-
fig, ax : tuple of matplotlib figure and axes
162-
163-
"""
164-
143+
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
144+
"""Bar plot of the autocorrelation function for a trace"""
165145
import matplotlib.pyplot as plt
146+
if fontmap is None:
147+
fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
166148

167149
if vars is None:
168150
vars = trace.varnames
@@ -171,13 +153,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
171153

172154
chains = trace.nchains
173155

174-
fig, ax = plt.subplots(len(vars), chains, squeeze=False)
156+
f, ax = plt.subplots(len(vars), chains, squeeze=False)
175157

176158
max_lag = min(len(trace) - 1, max_lag)
177159

178160
for i, v in enumerate(vars):
179161
for j in range(chains):
180-
d = np.squeeze(trace.get_values(v, chains=[j], burn=burn))
162+
d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin))
181163

182164
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
183165

@@ -187,8 +169,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
187169

188170
if chains > 1:
189171
ax[i, j].set_title("chain {0}".format(j+1))
190-
191-
return (fig, ax)
172+
173+
# Smaller tick labels
174+
tlabels = plt.gca().get_xticklabels()
175+
plt.setp(tlabels, 'fontsize', fontmap[1])
176+
177+
tlabels = plt.gca().get_yticklabels()
178+
plt.setp(tlabels, 'fontsize', fontmap[1])
192179

193180

194181
def var_str(name, shape):

0 commit comments

Comments
 (0)