@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
31
31
Returns
32
32
-------
33
33
34
- fig : figure object
34
+ fig, ax : tuple of matplotlib figure and axes
35
35
36
36
"""
37
37
import matplotlib .pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
69
69
pass
70
70
71
71
plt .tight_layout ()
72
- return fig
72
+ return ( fig , ax )
73
73
74
74
def histplot_op (ax , data ):
75
75
for i in range (data .shape [1 ]):
@@ -131,20 +131,38 @@ def kde2plot_op(ax, x, y, grid=200):
131
131
def kdeplot (data ):
132
132
f , ax = subplots (1 , 1 , squeeze = True )
133
133
kdeplot_op (ax , data )
134
- return f
134
+ return f , ax
135
135
136
136
137
137
def kde2plot (x , y , grid = 200 ):
138
138
f , ax = subplots (1 , 1 , squeeze = True )
139
139
kde2plot_op (ax , x , y , grid )
140
- return f
140
+ return f , ax
141
141
142
142
143
- def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144
- """Bar plot of the autocorrelation function for a trace"""
143
+ def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
144
+ """Bar plot of the autocorrelation function for a trace
145
+
146
+ Parameters
147
+ ----------
148
+
149
+ trace : result of MCMC run
150
+ vars : list of variable names
151
+ Variables to be plotted, if None all variable are plotted
152
+ max_lag: int
153
+ Maximum lag to calculate autocorrelation. Defaults to 100.
154
+ burn: int
155
+ Number of samples to discard from the beginning of the trace.
156
+ Defaults to 0.
157
+
158
+ Returns
159
+ -------
160
+
161
+ fig, ax : tuple of matplotlib figure and axes
162
+
163
+ """
164
+
145
165
import matplotlib .pyplot as plt
146
- if fontmap is None :
147
- fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
148
166
149
167
if vars is None :
150
168
vars = trace .varnames
@@ -153,13 +171,13 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
153
171
154
172
chains = trace .nchains
155
173
156
- f , ax = plt .subplots (len (vars ), chains , squeeze = False )
174
+ fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
157
175
158
176
max_lag = min (len (trace ) - 1 , max_lag )
159
177
160
178
for i , v in enumerate (vars ):
161
179
for j in range (chains ):
162
- d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
180
+ d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
163
181
164
182
ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
165
183
@@ -169,13 +187,8 @@ def autocorrplot(trace, vars=None, fontmap=None, max_lag=100,burn=0, thin=1):
169
187
170
188
if chains > 1 :
171
189
ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
172
-
173
- # Smaller tick labels
174
- tlabels = plt .gca ().get_xticklabels ()
175
- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176
-
177
- tlabels = plt .gca ().get_yticklabels ()
178
- plt .setp (tlabels , 'fontsize' , fontmap [1 ])
190
+
191
+ return (fig , ax )
179
192
180
193
181
194
def var_str (name , shape ):
0 commit comments