@@ -12,8 +12,9 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
12
12
colors = 'cycle' , outline = True , hpd_markers = '' , shade = 0. , bw = 4.5 , figsize = None ,
13
13
textsize = 12 , plot_transformed = False , ax = None ):
14
14
"""
15
- Generates KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or list of
16
- traces. KDE plots are grouped per variable and colors assigned to models.
15
+ Generates KDE plots for continuous variables and histograms for discretes ones.
16
+ Plots are truncated at their 100*(1-alpha)% credible intervals. Plots are grouped
17
+ per variable and colors assigned to models.
17
18
18
19
Parameters
19
20
----------
@@ -32,11 +33,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
32
33
Defaults to 'mean'.
33
34
colors : list or string, optional
34
35
List with valid matplotlib colors, one color per model. Alternative a string can be passed.
35
- If the string is `cycle `, it will automatically choose a color per model from matplolib's
36
+ If the string is `cycle`, it will automatically choose a color per model from matplolib's
36
37
cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all
37
- models. Defaults to 'C0' (blueish in most matplotlib styles)
38
+ models. Defaults to `cycle`.
38
39
outline : boolean
39
- Use a line to draw the truncated KDE and. Defaults to True
40
+ Use a line to draw KDEs and histograms. Default to True
40
41
hpd_markers : str
41
42
A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval.
42
43
Defaults to empty string (no marker).
@@ -64,7 +65,7 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
64
65
65
66
"""
66
67
if point_estimate not in ('mean' , 'median' , None ):
67
- raise ValueError ("Point estimate should be 'mean' or 'median'" )
68
+ raise ValueError ("Point estimate should be 'mean', 'median' or None " )
68
69
69
70
if not isinstance (trace , (list , tuple )):
70
71
trace = [trace ]
@@ -77,7 +78,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
77
78
else :
78
79
models = ['' ]
79
80
elif len (models ) != lenght_trace :
80
- raise ValueError ("The number of names for the models does not match the number of models" )
81
+ raise ValueError (
82
+ "The number of names for the models does not match the number of models" )
81
83
82
84
lenght_models = len (models )
83
85
@@ -97,8 +99,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
97
99
if figsize is None :
98
100
figsize = (6 , len (varnames ) * 2 )
99
101
100
- fig , kplot = plt .subplots (len (varnames ), 1 , squeeze = False , figsize = figsize )
101
- kplot = kplot .flatten ()
102
+ fig , dplot = plt .subplots (len (varnames ), 1 , squeeze = False , figsize = figsize )
103
+ dplot = dplot .flatten ()
102
104
103
105
for v_idx , vname in enumerate (varnames ):
104
106
for t_idx , tr in enumerate (trace ):
@@ -108,23 +110,24 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
108
110
if k > 1 :
109
111
vec = np .split (vec .T .ravel (), k )
110
112
for i in range (k ):
111
- _kde_helper (vec [i ], vname , colors [t_idx ], bw , alpha , point_estimate ,
112
- hpd_markers , outline , shade , kplot [v_idx ])
113
+ _d_helper (vec [i ], vname , colors [t_idx ], bw , alpha , point_estimate ,
114
+ hpd_markers , outline , shade , dplot [v_idx ])
115
+
113
116
else :
114
- _kde_helper (vec , vname , colors [t_idx ], bw , alpha , point_estimate ,
115
- hpd_markers , outline , shade , kplot [v_idx ])
117
+ _d_helper (vec , vname , colors [t_idx ], bw , alpha , point_estimate ,
118
+ hpd_markers , outline , shade , dplot [v_idx ])
116
119
117
120
if lenght_trace > 1 :
118
121
for m_idx , m in enumerate (models ):
119
- kplot [0 ].plot ([], label = m , c = colors [m_idx ])
120
- kplot [0 ].legend (fontsize = textsize )
122
+ dplot [0 ].plot ([], label = m , c = colors [m_idx ])
123
+ dplot [0 ].legend (fontsize = textsize )
121
124
122
125
fig .tight_layout ()
123
126
124
- return kplot
127
+ return dplot
125
128
126
129
127
- def _kde_helper (vec , vname , c , bw , alpha , point_estimate , hpd_markers , outline , shade , ax ):
130
+ def _d_helper (vec , vname , c , bw , alpha , point_estimate , hpd_markers , outline , shade , ax ):
128
131
"""
129
132
vec : array
130
133
1D array from trace
@@ -145,34 +148,42 @@ def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline,
145
148
(opaque). Defaults to 0.
146
149
ax : matplotlib axes
147
150
"""
148
- density , l , u = fast_kde (vec , bw )
149
- x = np .linspace (l , u , len (density ))
150
- hpd_ = hpd (vec , alpha )
151
- cut = (x >= hpd_ [0 ]) & (x <= hpd_ [1 ])
152
-
153
- xmin = x [cut ][0 ]
154
- xmax = x [cut ][- 1 ]
155
- ymin = density [cut ][0 ]
156
- ymax = density [cut ][- 1 ]
157
-
158
- if outline :
159
- ax .plot (x [cut ], density [cut ], color = c )
160
- ax .plot ([xmin , xmin ], [- 0.5 , ymin ], color = c , ls = '-' )
161
- ax .plot ([xmax , xmax ], [- 0.5 , ymax ], color = c , ls = '-' )
151
+ if vec .dtype .kind == 'f' :
152
+ density , l , u = fast_kde (vec )
153
+ x = np .linspace (l , u , len (density ))
154
+ hpd_ = hpd (vec , alpha )
155
+ cut = (x >= hpd_ [0 ]) & (x <= hpd_ [1 ])
156
+
157
+ xmin = x [cut ][0 ]
158
+ xmax = x [cut ][- 1 ]
159
+ ymin = density [cut ][0 ]
160
+ ymax = density [cut ][- 1 ]
161
+
162
+ if outline :
163
+ ax .plot (x [cut ], density [cut ], color = c )
164
+ ax .plot ([xmin , xmin ], [- ymin / 100 , ymin ], color = c , ls = '-' )
165
+ ax .plot ([xmax , xmax ], [- ymax / 100 , ymax ], color = c , ls = '-' )
166
+
167
+ if shade :
168
+ ax .fill_between (x , density , where = cut , color = c , alpha = shade )
169
+
170
+ else :
171
+ xmin , xmax = hpd (vec , alpha )
172
+ bins = range (xmin , xmax + 1 )
173
+ if outline :
174
+ ax .hist (vec , bins = bins , color = c , histtype = 'step' )
175
+ ax .hist (vec , bins = bins , color = c , alpha = shade )
162
176
163
177
if hpd_markers :
164
178
ax .plot (xmin , 0 , 'v' , color = c , markeredgecolor = 'k' )
165
179
ax .plot (xmax , 0 , 'v' , color = c , markeredgecolor = 'k' )
166
180
167
- if shade :
168
- ax .fill_between (x , density , where = cut , color = c , alpha = shade )
169
-
170
181
if point_estimate is not None :
171
182
if point_estimate == 'mean' :
172
183
ps = np .mean (vec )
173
- if point_estimate == 'median' :
184
+ elif point_estimate == 'median' :
174
185
ps = np .median (vec )
175
- ax .plot (ps , 0 , 'o' , color = c , markeredgecolor = 'k' )
186
+ ax .plot (ps , - 0.001 , 'o' , color = c , markeredgecolor = 'k' )
176
187
177
188
ax .set_yticks ([])
178
189
ax .set_title (vname )
0 commit comments