Skip to content

Commit 51f196d

Browse files
ColCarrolltwiecki
authored andcommitted
Accept extent as kwarg to kde_plot (#1509)
1 parent 1e8549e commit 51f196d

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

pymc3/plots.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ def kde2plot_op(ax, x, y, grid=200, **kwargs):
156156
xmax = x.max()
157157
ymin = y.min()
158158
ymax = y.max()
159+
extent = kwargs.pop('extent', [])
160+
if len(extent) != 4:
161+
extent = [xmin, xmax, ymin, ymax]
159162

160163
grid = grid * 1j
161164
X, Y = np.mgrid[xmin:xmax:grid, ymin:ymax:grid]
@@ -164,7 +167,7 @@ def kde2plot_op(ax, x, y, grid=200, **kwargs):
164167
kernel = kde.gaussian_kde(values)
165168
Z = np.reshape(kernel(positions).T, X.shape)
166169

167-
ax.imshow(np.rot90(Z), extent=[xmin, xmax, ymin, ymax], **kwargs)
170+
ax.imshow(np.rot90(Z), extent=extent, **kwargs)
168171

169172

170173
def kdeplot(data, ax=None):
@@ -363,9 +366,6 @@ def forestplot(trace_obj, varnames=None, transform=lambda x: x, alpha=0.05, quar
363366
# Range for x-axis
364367
plotrange = None
365368

366-
# Number of chains
367-
chains = None
368-
369369
# Subplots
370370
interval_plot = None
371371
rhat_plot = None
@@ -648,26 +648,22 @@ def plot_posterior(trace, varnames=None, transform=lambda x: x, figsize=None,
648648
"""
649649

650650
def plot_posterior_op(trace_values, ax):
651-
652651
def format_as_percent(x, round_to=0):
653-
value = np.round(100 * x, round_to)
654-
if round_to == 0:
655-
value = int(value)
656-
return '{}%'.format(value)
652+
return '{0:.{1:d}f}%'.format(100 * x, round_to)
657653

658654
def display_ref_val(ref_val):
659655
less_than_ref_probability = (trace_values < ref_val).mean()
660656
greater_than_ref_probability = (trace_values >= ref_val).mean()
661-
ref_in_posterior = format_as_percent(less_than_ref_probability, 1) + ' <{:g}< '.format(ref_val) + format_as_percent(
662-
greater_than_ref_probability, 1)
657+
ref_in_posterior = "{} <{:g}< {}".format(
658+
format_as_percent(less_than_ref_probability, 1),
659+
ref_val,
660+
format_as_percent(greater_than_ref_probability, 1))
663661
ax.axvline(ref_val, ymin=0.02, ymax=.75, color='g',
664662
linewidth=4, alpha=0.65)
665663
ax.text(trace_values.mean(), plot_height * 0.6, ref_in_posterior,
666664
size=14, horizontalalignment='center')
667665

668666
def display_rope(rope):
669-
pc_in_rope = format_as_percent(np.sum((trace_values > rope[0]) &
670-
(trace_values < rope[1])) / len(trace_values), round_to)
671667
ax.plot(rope, (plot_height * 0.02, plot_height * 0.02),
672668
linewidth=20, color='r', alpha=0.75)
673669
text_props = dict(size=16, horizontalalignment='center', color='r')
@@ -756,15 +752,15 @@ def create_axes_grid(figsize, traces):
756752
ax[-1].set_axis_off()
757753
ax = ax[:-1]
758754
return ax, fig
759-
755+
760756
def get_trace_dict(tr, varnames):
761757
traces = {}
762758
for v in varnames:
763759
vals = tr.get_values(v, combine=True, squeeze=True)
764-
if vals.ndim>1:
760+
if vals.ndim > 1:
765761
vals_flat = vals.reshape(vals.shape[0], -1).T
766-
for i,vi in enumerate(vals_flat):
767-
traces['_'.join([v,str(i)])] = vi
762+
for i, vi in enumerate(vals_flat):
763+
traces['_'.join([v, str(i)])] = vi
768764
else:
769765
traces[v] = vals
770766
return traces

0 commit comments

Comments
 (0)