7
7
8
8
9
9
def traceplot (trace , vars = None , figsize = None ,
10
- lines = None , combined = False , grid = True ):
10
+ lines = None , combined = False , grid = True , ax = None ):
11
11
"""Plot samples histograms and values
12
12
13
13
Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
27
27
(default), chains will be plotted separately.
28
28
grid : bool
29
29
Flag for adding gridlines to histogram. Defaults to True.
30
+ ax : axes
31
+ Matplotlib axes. Defaults to None.
30
32
31
33
Returns
32
34
-------
33
35
34
- fig : figure object
36
+ ax : matplotlib axes
35
37
36
38
"""
37
39
import matplotlib .pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
43
45
if figsize is None :
44
46
figsize = (12 , n * 2 )
45
47
46
- fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
48
+ if ax is None :
49
+ fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
50
+ elif ax .shape != (n ,2 ):
51
+ print ('traceplot requires n*2 subplots' )
52
+ return None
47
53
48
54
for i , v in enumerate (vars ):
49
55
for d in trace .get_values (v , combine = combined , squeeze = False ):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
69
75
pass
70
76
71
77
plt .tight_layout ()
72
- return fig
78
+ return ax
73
79
74
80
def histplot_op (ax , data ):
75
81
for i in range (data .shape [1 ]):
@@ -128,23 +134,45 @@ def kde2plot_op(ax, x, y, grid=200):
128
134
extent = [xmin , xmax , ymin , ymax ])
129
135
130
136
131
- def kdeplot (data ):
132
- f , ax = subplots (1 , 1 , squeeze = True )
137
+ def kdeplot (data , ax = None ):
138
+ if ax is None :
139
+ f , ax = subplots (1 , 1 , squeeze = True )
133
140
kdeplot_op (ax , data )
134
- return f
141
+ return ax
135
142
136
143
137
- def kde2plot (x , y , grid = 200 ):
138
- f , ax = subplots (1 , 1 , squeeze = True )
144
+ def kde2plot (x , y , grid = 200 , ax = None ):
145
+ if ax is None :
146
+ f , ax = subplots (1 , 1 , squeeze = True )
139
147
kde2plot_op (ax , x , y , grid )
140
- return f
148
+ return ax
141
149
142
150
143
- def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144
- """Bar plot of the autocorrelation function for a trace"""
151
+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 , ax = None ):
152
+ """Bar plot of the autocorrelation function for a trace
153
+
154
+ Parameters
155
+ ----------
156
+
157
+ trace : result of MCMC run
158
+ vars : list of variable names
159
+ Variables to be plotted, if None all variable are plotted
160
+ max_lag : int
161
+ Maximum lag to calculate autocorrelation. Defaults to 100.
162
+ burn : int
163
+ Number of samples to discard from the beginning of the trace.
164
+ Defaults to 0.
165
+ ax : axes
166
+ Matplotlib axes. Defaults to None.
167
+
168
+ Returns
169
+ -------
170
+
171
+ ax : matplotlib axes
172
+
173
+ """
174
+
145
175
import matplotlib .pyplot as plt
146
- if fontmap is None :
147
- fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
148
176
149
177
if vars is None :
150
178
vars = trace .varnames
@@ -153,13 +181,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153
181
154
182
chains = trace .nchains
155
183
156
- f , ax = plt .subplots (len (vars ), chains , squeeze = False )
184
+ fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
157
185
158
186
max_lag = min (len (trace ) - 1 , max_lag )
159
187
160
188
for i , v in enumerate (vars ):
161
189
for j in range (chains ):
162
- d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
190
+ d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
163
191
164
192
ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
165
193
@@ -169,13 +197,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169
197
170
198
if chains > 1 :
171
199
ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
172
-
173
- # Smaller tick labels
174
- tlabels = plt .gca ().get_xticklabels ()
175
- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176
-
177
- tlabels = plt .gca ().get_yticklabels ()
178
- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
200
+
201
+ return (fig , ax )
179
202
180
203
181
204
def var_str (name , shape ):
@@ -200,7 +223,7 @@ def var_str(name, shape):
200
223
201
224
def forestplot (trace_obj , vars = None , alpha = 0.05 , quartiles = True , rhat = True ,
202
225
main = None , xtitle = None , xrange = None , ylabels = None ,
203
- chain_spacing = 0.05 , vline = 0 ):
226
+ chain_spacing = 0.05 , vline = 0 , gs = None ):
204
227
""" Forest plot (model summary plot)
205
228
206
229
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -245,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
245
268
246
269
vline (optional): numeric
247
270
Location of vertical reference line (defaults to 0).
271
+
272
+ gs : GridSpec
273
+ Matplotlib GridSpec object. Defaults to None.
274
+
275
+ Returns
276
+ -------
277
+
278
+ gs : matplotlib GridSpec
248
279
249
280
"""
250
281
import matplotlib .pyplot as plt
@@ -270,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
270
301
# Number of chains
271
302
chains = None
272
303
273
- # Gridspec
274
- gs = None
275
-
276
304
# Subplots
277
305
interval_plot = None
278
306
rhat_plot = None
0 commit comments