@@ -31,7 +31,7 @@ def traceplot(trace, vars=None, figsize=None,
31
31
Returns
32
32
-------
33
33
34
- fig, ax : tuple of matplotlib figure and axes
34
+ fig : figure object
35
35
36
36
"""
37
37
import matplotlib .pyplot as plt
@@ -69,7 +69,7 @@ def traceplot(trace, vars=None, figsize=None,
69
69
pass
70
70
71
71
plt .tight_layout ()
72
- return ( fig , ax )
72
+ return fig
73
73
74
74
def histplot_op (ax , data ):
75
75
for i in range (data .shape [1 ]):
@@ -131,38 +131,20 @@ def kde2plot_op(ax, x, y, grid=200):
131
131
def kdeplot (data ):
132
132
f , ax = subplots (1 , 1 , squeeze = True )
133
133
kdeplot_op (ax , data )
134
- return f , ax
134
+ return f
135
135
136
136
137
137
def kde2plot (x , y , grid = 200 ):
138
138
f , ax = subplots (1 , 1 , squeeze = True )
139
139
kde2plot_op (ax , x , y , grid )
140
- return f , ax
140
+ return f
141
141
142
142
143
- def autocorrplot (trace , vars = None , max_lag = 100 , burn = 0 ):
144
- """Bar plot of the autocorrelation function for a trace
145
-
146
- Parameters
147
- ----------
148
-
149
- trace : result of MCMC run
150
- vars : list of variable names
151
- Variables to be plotted, if None all variable are plotted
152
- max_lag: int
153
- Maximum lag to calculate autocorrelation. Defaults to 100.
154
- burn: int
155
- Number of samples to discard from the beginning of the trace.
156
- Defaults to 0.
157
-
158
- Returns
159
- -------
160
-
161
- fig, ax : tuple of matplotlib figure and axes
162
-
163
- """
164
-
143
+ def autocorrplot (trace , vars = None , fontmap = None , max_lag = 100 ,burn = 0 , thin = 1 ):
144
+ """Bar plot of the autocorrelation function for a trace"""
165
145
import matplotlib .pyplot as plt
146
+ if fontmap is None :
147
+ fontmap = {1 : 10 , 2 : 8 , 3 : 6 , 4 : 5 , 5 : 4 }
166
148
167
149
if vars is None :
168
150
vars = trace .varnames
@@ -171,13 +153,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
171
153
172
154
chains = trace .nchains
173
155
174
- fig , ax = plt .subplots (len (vars ), chains , squeeze = False )
156
+ f , ax = plt .subplots (len (vars ), chains , squeeze = False )
175
157
176
158
max_lag = min (len (trace ) - 1 , max_lag )
177
159
178
160
for i , v in enumerate (vars ):
179
161
for j in range (chains ):
180
- d = np .squeeze (trace .get_values (v , chains = [j ], burn = burn ))
162
+ d = np .squeeze (trace .get_values (v , chains = [j ],burn = burn , thin = thin ))
181
163
182
164
ax [i , j ].acorr (d , detrend = plt .mlab .detrend_mean , maxlags = max_lag )
183
165
@@ -187,8 +169,13 @@ def autocorrplot(trace, vars=None, max_lag=100, burn=0):
187
169
188
170
if chains > 1 :
189
171
ax [i , j ].set_title ("chain {0}" .format (j + 1 ))
190
-
191
- return (fig , ax )
172
+
173
+ # Smaller tick labels
174
+ tlabels = plt .gca ().get_xticklabels ()
175
+ plt .setp (tlabels , 'fontsize' , fontmap [1 ])
176
+
177
+ tlabels = plt .gca ().get_yticklabels ()
178
+ plt .setp (tlabels , 'fontsize' , fontmap [1 ])
192
179
193
180
194
181
def var_str (name , shape ):
0 commit comments