@@ -101,25 +101,25 @@ def _plot_tree(ax, y, ntiles, show_quartiles, **plot_kwargs):
101
101
"""
102
102
if show_quartiles :
103
103
# Plot median
104
- ax .plot (ntiles [2 ], y , color = plot_kwargs .get ('color' , 'blue' ),
105
- marker = plot_kwargs .get ('marker' , 'o' ),
104
+ ax .plot (ntiles [2 ], y , color = plot_kwargs .get ('color' , 'blue' ),
105
+ marker = plot_kwargs .get ('marker' , 'o' ),
106
106
markersize = plot_kwargs .get ('markersize' , 4 ))
107
107
# Plot quartile interval
108
- ax .errorbar (x = (ntiles [1 ], ntiles [3 ]), y = (y , y ),
109
- linewidth = plot_kwargs .get ('linewidth' , 2 ),
108
+ ax .errorbar (x = (ntiles [1 ], ntiles [3 ]), y = (y , y ),
109
+ linewidth = plot_kwargs .get ('linewidth' , 2 ),
110
110
color = plot_kwargs .get ('color' , 'blue' ))
111
111
112
112
else :
113
113
# Plot median
114
- ax .plot (ntiles [1 ], y , marker = plot_kwargs .get ('marker' , 'o' ),
114
+ ax .plot (ntiles [1 ], y , marker = plot_kwargs .get ('marker' , 'o' ),
115
115
color = plot_kwargs .get ('color' , 'blue' ),
116
116
markersize = plot_kwargs .get ('markersize' , 4 ))
117
117
118
118
# Plot outer interval
119
- ax .errorbar (x = (ntiles [0 ], ntiles [- 1 ]), y = (y , y ),
120
- linewidth = int (plot_kwargs .get ('linewidth' , 2 )/ 2 ),
119
+ ax .errorbar (x = (ntiles [0 ], ntiles [- 1 ]), y = (y , y ),
120
+ linewidth = int (plot_kwargs .get ('linewidth' , 2 )/ 2 ),
121
121
color = plot_kwargs .get ('color' , 'blue' ))
122
-
122
+
123
123
return ax
124
124
125
125
@@ -227,11 +227,11 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
227
227
quants [- 1 ] = var_hpd [1 ].T
228
228
229
229
# Ensure x-axis contains range of current interval
230
- if plotrange :
230
+ if plotrange is None :
231
+ plotrange = [np .min (quants ), np .max (quants )]
232
+ else :
231
233
plotrange = [min (plotrange [0 ], np .min (quants )),
232
234
max (plotrange [1 ], np .max (quants ))]
233
- else :
234
- plotrange = [np .min (quants ), np .max (quants )]
235
235
236
236
# Number of elements in current variable
237
237
value = trace_obj .get_values (varname , chains = [chain ])[0 ]
@@ -255,12 +255,10 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
255
255
if k > 1 :
256
256
for q in np .transpose (quants ).squeeze ():
257
257
# Multiple y values
258
- interval_plot = _plot_tree (interval_plot , y , q , quartiles ,
259
- ** plot_kwargs )
258
+ interval_plot = _plot_tree (interval_plot , y , q , quartiles , ** plot_kwargs )
260
259
y -= 1
261
260
else :
262
- interval_plot = _plot_tree (interval_plot , y , quants , quartiles ,
263
- ** plot_kwargs )
261
+ interval_plot = _plot_tree (interval_plot , y , quants , quartiles , ** plot_kwargs )
264
262
265
263
# Increment index
266
264
var += k
0 commit comments