Skip to content

Commit bef2ee6

Browse files
author
Chris Fonnesbeck
committed
Plots return (fig, ax) tuple to prevent double-plotting; added docstring to autocorrplot
1 parent 96e03fc commit bef2ee6

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

pymc3/plots.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
3131
Returns
3232
-------
3333
34-
fig : figure object
34+
fig, ax : tuple of matplotlib figure and axes
3535
3636
"""
3737
import matplotlib.pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
6969
pass
7070

7171
plt.tight_layout()
72-
return fig
72+
return (fig, ax)
7373

7474
def histplot_op(ax, data):
7575
for i in range(data.shape[1]):
@@ -131,20 +131,38 @@ def kde2plot_op(ax, x, y, grid=200):
131131
def kdeplot(data):
132132
f, ax = subplots(1, 1, squeeze=True)
133133
kdeplot_op(ax, data)
134-
return f
134+
return f, ax
135135

136136

137137
def kde2plot(x, y, grid=200):
138138
f, ax = subplots(1, 1, squeeze=True)
139139
kde2plot_op(ax, x, y, grid)
140-
return f
140+
return f, ax
141141

142142

143-
def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
144-
"""Bar plot of the autocorrelation function for a trace"""
143+
def autocorrplot(trace, vars=None, max_lag=100, burn=0):
144+
"""Bar plot of the autocorrelation function for a trace
145+
146+
Parameters
147+
----------
148+
149+
trace : result of MCMC run
150+
vars : list of variable names
151+
Variables to be plotted, if None all variable are plotted
152+
max_lag: int
153+
Maximum lag to calculate autocorrelation. Defaults to 100.
154+
burn: int
155+
Number of samples to discard from the beginning of the trace.
156+
Defaults to 0.
157+
158+
Returns
159+
-------
160+
161+
fig, ax : tuple of matplotlib figure and axes
162+
163+
"""
164+
145165
import matplotlib.pyplot as plt
146-
if fontmap is None:
147-
fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
148166

149167
if vars is None:
150168
vars = trace.varnames
@@ -153,13 +171,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153171

154172
chains = trace.nchains
155173

156-
f, ax = plt.subplots(len(vars), chains, squeeze=False)
174+
fig, ax = plt.subplots(len(vars), chains, squeeze=False)
157175

158176
max_lag = min(len(trace) - 1, max_lag)
159177

160178
for i, v in enumerate(vars):
161179
for j in range(chains):
162-
d = np.squeeze(trace.get_values(v, chains=[j],burn=burn,thin=thin))
180+
d = np.squeeze(trace.get_values(v, chains=[j], burn=burn))
163181

164182
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
165183

@@ -169,13 +187,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169187

170188
if chains > 1:
171189
ax[i, j].set_title("chain {0}".format(j+1))
172-
173-
# Smaller tick labels
174-
tlabels = plt.gca().get_xticklabels()
175-
plt.setp(tlabels, 'fontsize', fontmap[1])
176-
177-
tlabels = plt.gca().get_yticklabels()
178-
plt.setp(tlabels, 'fontsize', fontmap[1])
190+
191+
return (fig, ax)
179192

180193

181194
def var_str(name, shape):

0 commit comments

Comments
 (0)