Skip to content

Commit d2528b8

Browse files
aloctavodiaspringcoil
authored andcommitted
show transformed variables when requested by the user (#1386)
* show transformed when requested by the user * add tests plot_transformed and include_transformed
1 parent df0a64b commit d2528b8

File tree

4 files changed

+45
-7
lines changed

4 files changed

+45
-7
lines changed

pymc3/plots.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def traceplot(trace, varnames=None, transform=lambda x: x, figsize=None,
5656
"""
5757

5858
if varnames is None:
59-
varnames = [name for name in trace.varnames if not name.endswith('_')]
59+
if plot_transformed:
60+
varnames = [name for name in trace.varnames]
61+
else:
62+
varnames = [name for name in trace.varnames if not name.endswith('_')]
6063

6164
n = len(varnames)
6265

@@ -217,7 +220,10 @@ def _handle_array_varnames(varname):
217220
yield varname
218221

219222
if varnames is None:
220-
varnames = [name for name in trace.varnames if not name.endswith('_')]
223+
if plot_transformed:
224+
varnames = [name for name in trace.varnames]
225+
else:
226+
varnames = [name for name in trace.varnames if not name.endswith('_')]
221227

222228
varnames = [item for sub in [[i for i in _handle_array_varnames(v)]
223229
for v in varnames] for item in sub]
@@ -758,8 +764,10 @@ def create_axes_grid(figsize, varnames):
758764
plot_posterior_op(transform(trace), ax)
759765
else:
760766
if varnames is None:
761-
varnames = [
762-
name for name in trace.varnames if not name.endswith('_')]
767+
if plot_transformed:
768+
varnames = [name for name in trace.varnames]
769+
else:
770+
varnames = [name for name in trace.varnames if not name.endswith('_')]
763771

764772
if ax is None:
765773
ax, fig = create_axes_grid(figsize, varnames)

pymc3/stats.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,10 @@ def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_tran
485485
mu__1 0.067513 -0.159097 -0.045637 0.062912
486486
"""
487487
if varnames is None:
488-
varnames = [name for name in trace.varnames if not name.endswith('_')]
488+
if include_transformed:
489+
varnames = [name for name in trace.varnames]
490+
else:
491+
varnames = [name for name in trace.varnames if not name.endswith('_')]
489492

490493
funcs = [lambda x: pd.Series(np.mean(x, 0), name='mean'),
491494
lambda x: pd.Series(np.std(x, 0), name='sd'),
@@ -550,7 +553,10 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
550553
551554
"""
552555
if varnames is None:
553-
varnames = [name for name in trace.varnames if not name.endswith('_')]
556+
if include_transformed:
557+
varnames = [name for name in trace.varnames]
558+
else:
559+
varnames = [name for name in trace.varnames if not name.endswith('_')]
554560

555561
stat_summ = _StatSummary(roundto, batches, alpha)
556562
pq_summ = _PosteriorQuantileSummary(roundto, alpha)

pymc3/tests/test_plots.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
matplotlib.use('Agg', warn=False)
33

44
import numpy as np
5+
import pymc3 as pm
56
from .checks import close_to
67

78
from .models import multidimensional_model
8-
from ..plots import traceplot, forestplot, autocorrplot, make_2d
9+
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, make_2d
910
from ..step_methods import Slice, Metropolis
1011
from ..sampling import sample
1112
from ..tuning.scaling import find_hessian
@@ -64,3 +65,17 @@ def test_make_2d():
6465
assert res.shape == (n, 20)
6566
close_to(a[:, 0, 0], res[:, 0], 0)
6667
close_to(a[:, 3, 2], res[:, 2 * 4 + 3], 0)
68+
69+
70+
def test_plots_transformed():
71+
with pm.Model() as model:
72+
pm.Uniform('x', 0, 1)
73+
step = pm.Metropolis()
74+
trace = pm.sample(100, step=step)
75+
76+
assert traceplot(trace).shape == (1, 2)
77+
assert traceplot(trace, plot_transformed=True).shape == (2, 2)
78+
assert autocorrplot(trace).shape == (1, 1)
79+
assert autocorrplot(trace, plot_transformed=True).shape == (2, 1)
80+
assert plot_posterior(trace).shape == (1, )
81+
assert plot_posterior(trace, plot_transformed=True).shape == (2, )

pymc3/tests/test_stats.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,12 @@ def test_value_alignment(self):
366366
else:
367367
vidx = var
368368
npt.assert_equal(val, ds.loc[vidx, 'mean'])
369+
370+
def test_row_names(self):
371+
with Model() as model:
372+
pm.Uniform('x', 0, 1)
373+
step = Metropolis()
374+
trace = pm.sample(100, step=step)
375+
ds = df_summary(trace, batches=3, include_transformed=True)
376+
npt.assert_equal(np.array(['x_interval_', 'x']),
377+
ds.index)

0 commit comments

Comments
 (0)