Skip to content

Commit be532de

Browse files
aloctavodiaJunpeng Lao
authored andcommitted
Densityplot: add support for discrete variables (#2878)
* densityplot: add support for discrete variables * fix error with xticks, add to release notes * fix typo
1 parent fde52a4 commit be532de

File tree

2 files changed

+48
-36
lines changed

2 files changed

+48
-36
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Optionally it can plot divergences.
1616
- Plots of discrete distributions in the docstrings
1717
- Add logitnormal distribution
18+
- Densityplot: add support for discrete variables
1819

1920
### Fixes
2021

pymc3/plots/densityplot.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
1212
colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None,
1313
textsize=12, plot_transformed=False, ax=None):
1414
"""
15-
Generates KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or list of
16-
traces. KDE plots are grouped per variable and colors assigned to models.
15+
Generates KDE plots for continuous variables and histograms for discretes ones.
16+
Plots are truncated at their 100*(1-alpha)% credible intervals. Plots are grouped
17+
per variable and colors assigned to models.
1718
1819
Parameters
1920
----------
@@ -32,11 +33,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
3233
Defaults to 'mean'.
3334
colors : list or string, optional
3435
List with valid matplotlib colors, one color per model. Alternative a string can be passed.
35-
If the string is `cycle `, it will automatically choose a color per model from matplolib's
36+
If the string is `cycle`, it will automatically choose a color per model from matplolib's
3637
cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all
37-
models. Defaults to 'C0' (blueish in most matplotlib styles)
38+
models. Defaults to `cycle`.
3839
outline : boolean
39-
Use a line to draw the truncated KDE and. Defaults to True
40+
Use a line to draw KDEs and histograms. Default to True
4041
hpd_markers : str
4142
A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval.
4243
Defaults to empty string (no marker).
@@ -64,7 +65,7 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
6465
6566
"""
6667
if point_estimate not in ('mean', 'median', None):
67-
raise ValueError("Point estimate should be 'mean' or 'median'")
68+
raise ValueError("Point estimate should be 'mean', 'median' or None")
6869

6970
if not isinstance(trace, (list, tuple)):
7071
trace = [trace]
@@ -77,7 +78,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
7778
else:
7879
models = ['']
7980
elif len(models) != lenght_trace:
80-
raise ValueError("The number of names for the models does not match the number of models")
81+
raise ValueError(
82+
"The number of names for the models does not match the number of models")
8183

8284
lenght_models = len(models)
8385

@@ -97,8 +99,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
9799
if figsize is None:
98100
figsize = (6, len(varnames) * 2)
99101

100-
fig, kplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
101-
kplot = kplot.flatten()
102+
fig, dplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
103+
dplot = dplot.flatten()
102104

103105
for v_idx, vname in enumerate(varnames):
104106
for t_idx, tr in enumerate(trace):
@@ -108,23 +110,24 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
108110
if k > 1:
109111
vec = np.split(vec.T.ravel(), k)
110112
for i in range(k):
111-
_kde_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate,
112-
hpd_markers, outline, shade, kplot[v_idx])
113+
_d_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate,
114+
hpd_markers, outline, shade, dplot[v_idx])
115+
113116
else:
114-
_kde_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate,
115-
hpd_markers, outline, shade, kplot[v_idx])
117+
_d_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate,
118+
hpd_markers, outline, shade, dplot[v_idx])
116119

117120
if lenght_trace > 1:
118121
for m_idx, m in enumerate(models):
119-
kplot[0].plot([], label=m, c=colors[m_idx])
120-
kplot[0].legend(fontsize=textsize)
122+
dplot[0].plot([], label=m, c=colors[m_idx])
123+
dplot[0].legend(fontsize=textsize)
121124

122125
fig.tight_layout()
123126

124-
return kplot
127+
return dplot
125128

126129

127-
def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax):
130+
def _d_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax):
128131
"""
129132
vec : array
130133
1D array from trace
@@ -145,34 +148,42 @@ def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline,
145148
(opaque). Defaults to 0.
146149
ax : matplotlib axes
147150
"""
148-
density, l, u = fast_kde(vec, bw)
149-
x = np.linspace(l, u, len(density))
150-
hpd_ = hpd(vec, alpha)
151-
cut = (x >= hpd_[0]) & (x <= hpd_[1])
152-
153-
xmin = x[cut][0]
154-
xmax = x[cut][-1]
155-
ymin = density[cut][0]
156-
ymax = density[cut][-1]
157-
158-
if outline:
159-
ax.plot(x[cut], density[cut], color=c)
160-
ax.plot([xmin, xmin], [-0.5, ymin], color=c, ls='-')
161-
ax.plot([xmax, xmax], [-0.5, ymax], color=c, ls='-')
151+
if vec.dtype.kind == 'f':
152+
density, l, u = fast_kde(vec)
153+
x = np.linspace(l, u, len(density))
154+
hpd_ = hpd(vec, alpha)
155+
cut = (x >= hpd_[0]) & (x <= hpd_[1])
156+
157+
xmin = x[cut][0]
158+
xmax = x[cut][-1]
159+
ymin = density[cut][0]
160+
ymax = density[cut][-1]
161+
162+
if outline:
163+
ax.plot(x[cut], density[cut], color=c)
164+
ax.plot([xmin, xmin], [-ymin/100, ymin], color=c, ls='-')
165+
ax.plot([xmax, xmax], [-ymax/100, ymax], color=c, ls='-')
166+
167+
if shade:
168+
ax.fill_between(x, density, where=cut, color=c, alpha=shade)
169+
170+
else:
171+
xmin, xmax = hpd(vec, alpha)
172+
bins = range(xmin, xmax+1)
173+
if outline:
174+
ax.hist(vec, bins=bins, color=c, histtype='step')
175+
ax.hist(vec, bins=bins, color=c, alpha=shade)
162176

163177
if hpd_markers:
164178
ax.plot(xmin, 0, 'v', color=c, markeredgecolor='k')
165179
ax.plot(xmax, 0, 'v', color=c, markeredgecolor='k')
166180

167-
if shade:
168-
ax.fill_between(x, density, where=cut, color=c, alpha=shade)
169-
170181
if point_estimate is not None:
171182
if point_estimate == 'mean':
172183
ps = np.mean(vec)
173-
if point_estimate == 'median':
184+
elif point_estimate == 'median':
174185
ps = np.median(vec)
175-
ax.plot(ps, 0, 'o', color=c, markeredgecolor='k')
186+
ax.plot(ps, -0.001, 'o', color=c, markeredgecolor='k')
176187

177188
ax.set_yticks([])
178189
ax.set_title(vname)

0 commit comments

Comments
 (0)