Skip to content

Commit c16b785

Browse files
author
Chris Fonnesbeck
committed
Main plots accept/return axes or GridSpec
1 parent cd07815 commit c16b785

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

pymc3/plots.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def traceplot(trace, vars=None, figsize=None,
10-
lines=None, combined=False, grid=True):
10+
lines=None, combined=False, grid=True, ax=None):
1111
"""Plot samples histograms and values
1212
1313
Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
2727
(default), chains will be plotted separately.
2828
grid : bool
2929
Flag for adding gridlines to histogram. Defaults to True.
30+
ax : axes
31+
Matplotlib axes. Defaults to None.
3032
3133
Returns
3234
-------
3335
34-
fig, ax : tuple of matplotlib figure and axes
36+
ax : matplotlib axes
3537
3638
"""
3739
import matplotlib.pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
4345
if figsize is None:
4446
figsize = (12, n*2)
4547

46-
fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
48+
if ax is None:
49+
fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
50+
elif ax.shape != (n,2):
51+
print('traceplot requires n*2 subplots')
52+
return None
4753

4854
for i, v in enumerate(vars):
4955
for d in trace.get_values(v, combine=combined, squeeze=False):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
6975
pass
7076

7177
plt.tight_layout()
72-
return (fig, ax)
78+
return ax
7379

7480
def histplot_op(ax, data):
7581
for i in range(data.shape[1]):
@@ -128,19 +134,21 @@ def kde2plot_op(ax, x, y, grid=200):
128134
extent=[xmin, xmax, ymin, ymax])
129135

130136

131-
def kdeplot(data):
132-
f, ax = subplots(1, 1, squeeze=True)
137+
def kdeplot(data, ax=None):
138+
if ax is None:
139+
f, ax = subplots(1, 1, squeeze=True)
133140
kdeplot_op(ax, data)
134-
return f, ax
141+
return ax
135142

136143

137-
def kde2plot(x, y, grid=200):
138-
f, ax = subplots(1, 1, squeeze=True)
144+
def kde2plot(x, y, grid=200, ax=None):
145+
if ax is None:
146+
f, ax = subplots(1, 1, squeeze=True)
139147
kde2plot_op(ax, x, y, grid)
140-
return f, ax
148+
return ax
141149

142150

143-
def autocorrplot(trace, vars=None, max_lag=100, burn=0):
151+
def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None):
144152
"""Bar plot of the autocorrelation function for a trace
145153
146154
Parameters
@@ -149,16 +157,18 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
149157
trace : result of MCMC run
150158
vars : list of variable names
151159
Variables to be plotted, if None all variable are plotted
152-
max_lag: int
160+
max_lag : int
153161
Maximum lag to calculate autocorrelation. Defaults to 100.
154-
burn: int
162+
burn : int
155163
Number of samples to discard from the beginning of the trace.
156164
Defaults to 0.
157-
165+
ax : axes
166+
Matplotlib axes. Defaults to None.
167+
158168
Returns
159169
-------
160170
161-
fig, ax : tuple of matplotlib figure and axes
171+
ax : matplotlib axes
162172
163173
"""
164174

@@ -213,7 +223,7 @@ def var_str(name, shape):
213223

214224
def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
215225
main=None, xtitle=None, xrange=None, ylabels=None,
216-
chain_spacing=0.05, vline=0):
226+
chain_spacing=0.05, vline=0, gs=None):
217227
""" Forest plot (model summary plot)
218228
219229
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -258,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
258268
259269
vline (optional): numeric
260270
Location of vertical reference line (defaults to 0).
271+
272+
gs : GridSpec
273+
Matplotlib GridSpec object. Defaults to None.
274+
275+
Returns
276+
-------
277+
278+
gs : matplotlib GridSpec
261279
262280
"""
263281
import matplotlib.pyplot as plt
@@ -283,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
283301
# Number of chains
284302
chains = None
285303

286-
# Gridspec
287-
gs = None
288-
289304
# Subplots
290305
interval_plot = None
291306
rhat_plot = None

0 commit comments

Comments
 (0)