7
7
8
8
9
9
def traceplot (trace , vars = None , figsize = None ,
10
- lines = None , combined = False , grid = True ):
10
+ lines = None , combined = False , grid = True , ax = None ):
11
11
"""Plot samples histograms and values
12
12
13
13
Parameters
@@ -27,11 +27,13 @@ def traceplot(trace, vars=None, figsize=None,
27
27
(default), chains will be plotted separately.
28
28
grid : bool
29
29
Flag for adding gridlines to histogram. Defaults to True.
30
+ ax : axes
31
+ Matplotlib axes. Defaults to None.
30
32
31
33
Returns
32
34
-------
33
35
34
- fig, ax : tuple of matplotlib figure and axes
36
+ ax : matplotlib axes
35
37
36
38
"""
37
39
import matplotlib .pyplot as plt
@@ -43,7 +45,11 @@ def traceplot(trace, vars=None, figsize=None,
43
45
if figsize is None :
44
46
figsize = (12 , n * 2 )
45
47
46
- fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
48
+ if ax is None :
49
+ fig , ax = plt .subplots (n , 2 , squeeze = False , figsize = figsize )
50
+ elif ax .shape != (n ,2 ):
51
+ print ('traceplot requires n*2 subplots' )
52
+ return None
47
53
48
54
for i , v in enumerate (vars ):
49
55
for d in trace .get_values (v , combine = combined , squeeze = False ):
@@ -69,7 +75,7 @@ def traceplot(trace, vars=None, figsize=None,
69
75
pass
70
76
71
77
plt .tight_layout ()
72
- return ( fig , ax )
78
+ return ax
73
79
74
80
def histplot_op (ax , data ):
75
81
for i in range (data .shape [1 ]):
@@ -128,19 +134,21 @@ def kde2plot_op(ax, x, y, grid=200):
128
134
extent = [xmin , xmax , ymin , ymax ])
129
135
130
136
131
- def kdeplot (data ):
132
- f , ax = subplots (1 , 1 , squeeze = True )
137
+ def kdeplot (data , ax = None ):
138
+ if ax is None :
139
+ f , ax = subplots (1 , 1 , squeeze = True )
133
140
kdeplot_op (ax , data )
134
- return f , ax
141
+ return ax
135
142
136
143
137
- def kde2plot (x , y , grid = 200 ):
138
- f , ax = subplots (1 , 1 , squeeze = True )
144
+ def kde2plot (x , y , grid = 200 , ax = None ):
145
+ if ax is None :
146
+ f , ax = subplots (1 , 1 , squeeze = True )
139
147
kde2plot_op (ax , x , y , grid )
140
- return f , ax
148
+ return ax
141
149
142
150
143
- def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
151
+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 , ax = None ):
144
152
"""Bar plot of the autocorrelation function for a trace
145
153
146
154
Parameters
@@ -149,16 +157,18 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
149
157
trace : result of MCMC run
150
158
vars : list of variable names
151
159
Variables to be plotted, if None all variable are plotted
152
- max_lag: int
160
+ max_lag : int
153
161
Maximum lag to calculate autocorrelation. Defaults to 100.
154
- burn: int
162
+ burn : int
155
163
Number of samples to discard from the beginning of the trace.
156
164
Defaults to 0.
157
-
165
+ ax : axes
166
+ Matplotlib axes. Defaults to None.
167
+
158
168
Returns
159
169
-------
160
170
161
- fig, ax : tuple of matplotlib figure and axes
171
+ ax : matplotlib axes
162
172
163
173
"""
164
174
@@ -213,7 +223,7 @@ def var_str(name, shape):
213
223
214
224
def forestplot (trace_obj , vars = None , alpha = 0.05 , quartiles = True , rhat = True ,
215
225
main = None , xtitle = None , xrange = None , ylabels = None ,
216
- chain_spacing = 0.05 , vline = 0 ):
226
+ chain_spacing = 0.05 , vline = 0 , gs = None ):
217
227
""" Forest plot (model summary plot)
218
228
219
229
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
@@ -258,6 +268,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
258
268
259
269
vline (optional): numeric
260
270
Location of vertical reference line (defaults to 0).
271
+
272
+ gs : GridSpec
273
+ Matplotlib GridSpec object. Defaults to None.
274
+
275
+ Returns
276
+ -------
277
+
278
+ gs : matplotlib GridSpec
261
279
262
280
"""
263
281
import matplotlib .pyplot as plt
@@ -283,9 +301,6 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
283
301
# Number of chains
284
302
chains = None
285
303
286
- # Gridspec
287
- gs = None
288
-
289
304
# Subplots
290
305
interval_plot = None
291
306
rhat_plot = None
0 commit comments