4
4
import pandas as pd
5
5
import itertools
6
6
import sys
7
+ from tqdm import tqdm
7
8
import warnings
8
9
from collections import namedtuple
9
10
from .model import modelcontext
@@ -124,14 +125,18 @@ def dic(trace, model=None):
124
125
return 2 * mean_deviance - deviance_at_mean
125
126
126
127
127
- def _log_post_trace (trace , model ):
128
+ def _log_post_trace (trace , model , progressbar = False ):
128
129
"""Calculate the elementwise log-posterior for the sampled trace.
129
130
130
131
Parameters
131
132
----------
132
133
trace : result of MCMC run
133
134
model : PyMC Model
134
135
Optional model. Default None, taken from context.
136
+ progressbar: bool
137
+ Whether or not to display a progress bar in the command line. The
138
+ bar shows the percentage of completion, the evaluation speed, and
139
+ the estimated time to completion
135
140
136
141
Returns
137
142
-------
@@ -151,11 +156,17 @@ def logp_vals_point(pt):
151
156
152
157
return np .concatenate (logp_vals )
153
158
154
- logp = (logp_vals_point (pt ) for pt in trace )
155
- return np .stack (logp )
159
+ points = tqdm (trace ) if progressbar else trace
160
+
161
+ try :
162
+ logp = (logp_vals_point (pt ) for pt in points )
163
+ return np .stack (logp )
164
+ finally :
165
+ if progressbar :
166
+ points .close ()
156
167
157
168
158
- def waic (trace , model = None , pointwise = False ):
169
+ def waic (trace , model = None , pointwise = False , progressbar = False ):
159
170
"""Calculate the widely available information criterion, its standard error
160
171
and the effective number of parameters of the samples in trace from model.
161
172
Read more theory here - in a paper by some of the leading authorities on
@@ -169,6 +180,10 @@ def waic(trace, model=None, pointwise=False):
169
180
pointwise: bool
170
181
if True the pointwise predictive accuracy will be returned.
171
182
Default False
183
+ progressbar: bool
184
+ Whether or not to display a progress bar in the command line. The
185
+ bar shows the percentage of completion, the evaluation speed, and
186
+ the estimated time to completion
172
187
173
188
Returns
174
189
-------
@@ -180,7 +195,7 @@ def waic(trace, model=None, pointwise=False):
180
195
"""
181
196
model = modelcontext (model )
182
197
183
- log_py = _log_post_trace (trace , model )
198
+ log_py = _log_post_trace (trace , model , progressbar = progressbar )
184
199
if log_py .size == 0 :
185
200
raise ValueError ('The model does not contain observed values.' )
186
201
@@ -208,7 +223,7 @@ def waic(trace, model=None, pointwise=False):
208
223
return WAIC_r (waic , waic_se , p_waic )
209
224
210
225
211
- def loo (trace , model = None , pointwise = False ):
226
+ def loo (trace , model = None , pointwise = False , progressbar = False ):
212
227
"""Calculates leave-one-out (LOO) cross-validation for out of sample predictive
213
228
model fit, following Vehtari et al. (2015). Cross-validation is computed using
214
229
Pareto-smoothed importance sampling (PSIS).
@@ -221,6 +236,10 @@ def loo(trace, model=None, pointwise=False):
221
236
pointwise: bool
222
237
if True the pointwise predictive accuracy will be returned.
223
238
Default False
239
+ progressbar: bool
240
+ Whether or not to display a progress bar in the command line. The
241
+ bar shows the percentage of completion, the evaluation speed, and
242
+ the estimated time to completion
224
243
225
244
Returns
226
245
-------
@@ -232,7 +251,7 @@ def loo(trace, model=None, pointwise=False):
232
251
"""
233
252
model = modelcontext (model )
234
253
235
- log_py = _log_post_trace (trace , model )
254
+ log_py = _log_post_trace (trace , model , progressbar = progressbar )
236
255
if log_py .size == 0 :
237
256
raise ValueError ('The model does not contain observed values.' )
238
257
0 commit comments