@@ -80,7 +80,7 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
80
80
return ax
81
81
82
82
83
- def _plot_tree (ax , y , ntiles , show_quartiles , ** plot_kwargs ):
83
+ def _plot_tree (ax , y , ntiles , show_quartiles , plot_kwargs ):
84
84
"""Helper to plot errorbars for the forestplot.
85
85
86
86
Parameters
@@ -123,10 +123,10 @@ def _plot_tree(ax, y, ntiles, show_quartiles, **plot_kwargs):
123
123
return ax
124
124
125
125
126
- def forestplot (trace_obj , varnames = None , transform = identity_transform , alpha = 0.05 , quartiles = True ,
127
- rhat = True , main = None , xtitle = None , xlim = None , ylabels = None ,
128
- chain_spacing = 0.05 , vline = 0 , gs = None , plot_transformed = False ,
129
- ** plot_kwargs ):
126
+ def forestplot (trace_obj , varnames = None , transform = identity_transform ,
127
+ alpha = 0.05 , quartiles = True , rhat = True , main = None , xtitle = None ,
128
+ xlim = None , ylabels = None , chain_spacing = 0.05 , vline = 0 , gs = None ,
129
+ plot_transformed = False , plot_kwargs = None ):
130
130
"""
131
131
Forest plot (model summary plot).
132
132
@@ -180,6 +180,9 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
180
180
gs : matplotlib GridSpec
181
181
182
182
"""
183
+ if plot_kwargs is None :
184
+ plot_kwargs = {}
185
+
183
186
# Quantiles to be calculated
184
187
if quartiles :
185
188
qlist = [100 * alpha / 2 , 25 , 50 , 75 , 100 * (1 - alpha / 2 )]
@@ -209,7 +212,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
209
212
# Subplot for confidence intervals
210
213
interval_plot = plt .subplot (gs [0 ])
211
214
212
- trace_quantiles = quantiles (trace_obj , qlist , transform = transform , squeeze = False )
215
+ trace_quantiles = quantiles (trace_obj , qlist , transform = transform ,
216
+ squeeze = False )
213
217
hpd_intervals = hpd (trace_obj , alpha , transform = transform , squeeze = False )
214
218
215
219
labels = []
@@ -246,7 +250,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
246
250
labels .append (varname )
247
251
248
252
# Add spacing for each chain, if more than one
249
- offset = [0 ] + [(chain_spacing * ((i + 2 ) / 2 )) * (- 1 ) ** i for i in range (nchains - 1 )]
253
+ offset = [0 ] + [(chain_spacing * ((i + 2 ) / 2 )) *
254
+ (- 1 ) ** i for i in range (nchains - 1 )]
250
255
251
256
# Y coordinate with offset
252
257
y = - var + offset [j ]
@@ -255,10 +260,12 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
255
260
if k > 1 :
256
261
for q in np .transpose (quants ).squeeze ():
257
262
# Multiple y values
258
- interval_plot = _plot_tree (interval_plot , y , q , quartiles , ** plot_kwargs )
263
+ interval_plot = _plot_tree (interval_plot , y , q , quartiles ,
264
+ plot_kwargs )
259
265
y -= 1
260
266
else :
261
- interval_plot = _plot_tree (interval_plot , y , quants , quartiles , ** plot_kwargs )
267
+ interval_plot = _plot_tree (interval_plot , y , quants , quartiles ,
268
+ plot_kwargs )
262
269
263
270
# Increment index
264
271
var += k
@@ -273,11 +280,13 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
273
280
interval_plot .set_ylim (- var + 0.5 , 0.5 )
274
281
275
282
datarange = plotrange [1 ] - plotrange [0 ]
276
- interval_plot .set_xlim (plotrange [0 ] - 0.05 * datarange , plotrange [1 ] + 0.05 * datarange )
283
+ interval_plot .set_xlim (plotrange [0 ] - 0.05 * datarange ,
284
+ plotrange [1 ] + 0.05 * datarange )
277
285
278
286
# Add variable labels
279
287
interval_plot .set_yticks ([- l for l in range (len (labels ))])
280
- interval_plot .set_yticklabels (labels , fontsize = plot_kwargs .get ('fontsize' , None ))
288
+ interval_plot .set_yticklabels (labels ,
289
+ fontsize = plot_kwargs .get ('fontsize' , None ))
281
290
282
291
# Add title
283
292
plot_title = ""
@@ -286,7 +295,8 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
286
295
elif main :
287
296
plot_title = main
288
297
if plot_title :
289
- interval_plot .set_title (plot_title , fontsize = plot_kwargs .get ('fontsize' , None ))
298
+ interval_plot .set_title (plot_title ,
299
+ fontsize = plot_kwargs .get ('fontsize' , None ))
290
300
291
301
# Add x-axis label
292
302
if xtitle is not None :
@@ -310,6 +320,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
310
320
311
321
# Genenerate Gelman-Rubin plot
312
322
if plot_rhat :
313
- _make_rhat_plot (trace_obj , plt .subplot (gs [1 ]), "R-hat" , labels , varnames , plot_transformed )
323
+ _make_rhat_plot (trace_obj , plt .subplot (gs [1 ]), "R-hat" , labels ,
324
+ varnames , plot_transformed )
314
325
315
326
return gs
0 commit comments