Skip to content

Commit 177fa08

Browse files
aloctavodiatwiecki
authored andcommitted
fix bug, forestplot was silently not accepting plot_kwargs
1 parent 1a58aa3 commit 177fa08

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

pymc3/plots/forestplot.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
8080
return ax
8181

8282

83-
def _plot_tree(ax, y, ntiles, show_quartiles, **plot_kwargs):
83+
def _plot_tree(ax, y, ntiles, show_quartiles, plot_kwargs):
8484
"""Helper to plot errorbars for the forestplot.
8585
8686
Parameters
@@ -123,10 +123,10 @@ def _plot_tree(ax, y, ntiles, show_quartiles, **plot_kwargs):
123123
return ax
124124

125125

126-
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
127-
rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
128-
chain_spacing=0.05, vline=0, gs=None, plot_transformed=False,
129-
**plot_kwargs):
126+
def forestplot(trace_obj, varnames=None, transform=identity_transform,
127+
alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
128+
xlim=None, ylabels=None, chain_spacing=0.05, vline=0, gs=None,
129+
plot_transformed=False, plot_kwargs=None):
130130
"""
131131
Forest plot (model summary plot).
132132
@@ -180,6 +180,9 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
180180
gs : matplotlib GridSpec
181181
182182
"""
183+
if plot_kwargs is None:
184+
plot_kwargs = {}
185+
183186
# Quantiles to be calculated
184187
if quartiles:
185188
qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
@@ -209,7 +212,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
209212
# Subplot for confidence intervals
210213
interval_plot = plt.subplot(gs[0])
211214

212-
trace_quantiles = quantiles(trace_obj, qlist, transform=transform, squeeze=False)
215+
trace_quantiles = quantiles(trace_obj, qlist, transform=transform,
216+
squeeze=False)
213217
hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)
214218

215219
labels = []
@@ -246,7 +250,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
246250
labels.append(varname)
247251

248252
# Add spacing for each chain, if more than one
249-
offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i for i in range(nchains - 1)]
253+
offset = [0] + [(chain_spacing * ((i + 2) / 2)) *
254+
(-1) ** i for i in range(nchains - 1)]
250255

251256
# Y coordinate with offset
252257
y = -var + offset[j]
@@ -255,10 +260,12 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
255260
if k > 1:
256261
for q in np.transpose(quants).squeeze():
257262
# Multiple y values
258-
interval_plot = _plot_tree(interval_plot, y, q, quartiles, **plot_kwargs)
263+
interval_plot = _plot_tree(interval_plot, y, q, quartiles,
264+
plot_kwargs)
259265
y -= 1
260266
else:
261-
interval_plot = _plot_tree(interval_plot, y, quants, quartiles, **plot_kwargs)
267+
interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
268+
plot_kwargs)
262269

263270
# Increment index
264271
var += k
@@ -273,11 +280,13 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
273280
interval_plot.set_ylim(-var + 0.5, 0.5)
274281

275282
datarange = plotrange[1] - plotrange[0]
276-
interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)
283+
interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
284+
plotrange[1] + 0.05 * datarange)
277285

278286
# Add variable labels
279287
interval_plot.set_yticks([-l for l in range(len(labels))])
280-
interval_plot.set_yticklabels(labels, fontsize=plot_kwargs.get('fontsize', None))
288+
interval_plot.set_yticklabels(labels,
289+
fontsize=plot_kwargs.get('fontsize', None))
281290

282291
# Add title
283292
plot_title = ""
@@ -286,7 +295,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
286295
elif main:
287296
plot_title = main
288297
if plot_title:
289-
interval_plot.set_title(plot_title, fontsize=plot_kwargs.get('fontsize', None))
298+
interval_plot.set_title(plot_title,
299+
fontsize=plot_kwargs.get('fontsize', None))
290300

291301
# Add x-axis label
292302
if xtitle is not None:
@@ -310,6 +320,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
310320

311321
# Genenerate Gelman-Rubin plot
312322
if plot_rhat:
313-
_make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels, varnames, plot_transformed)
323+
_make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels,
324+
varnames, plot_transformed)
314325

315326
return gs

0 commit comments

Comments
 (0)