Skip to content

Commit 844cc67

Browse files
Merge pull request #2456 from pymc-devs/waic-add-progressbar
Add progressbar option to WAIC
2 parents d3cb804 + 4d5997a commit 844cc67

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

pymc3/stats.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import itertools
66
import sys
7+
from tqdm import tqdm
78
import warnings
89
from collections import namedtuple
910
from .model import modelcontext
@@ -124,14 +125,18 @@ def dic(trace, model=None):
124125
return 2 * mean_deviance - deviance_at_mean
125126

126127

127-
def _log_post_trace(trace, model):
128+
def _log_post_trace(trace, model, progressbar=False):
128129
"""Calculate the elementwise log-posterior for the sampled trace.
129130
130131
Parameters
131132
----------
132133
trace : result of MCMC run
133134
model : PyMC Model
134135
Optional model. Default None, taken from context.
136+
progressbar: bool
137+
Whether or not to display a progress bar in the command line. The
138+
bar shows the percentage of completion, the evaluation speed, and
139+
the estimated time to completion
135140
136141
Returns
137142
-------
@@ -151,11 +156,17 @@ def logp_vals_point(pt):
151156

152157
return np.concatenate(logp_vals)
153158

154-
logp = (logp_vals_point(pt) for pt in trace)
155-
return np.stack(logp)
159+
points = tqdm(trace) if progressbar else trace
160+
161+
try:
162+
logp = (logp_vals_point(pt) for pt in points)
163+
return np.stack(logp)
164+
finally:
165+
if progressbar:
166+
points.close()
156167

157168

158-
def waic(trace, model=None, pointwise=False):
169+
def waic(trace, model=None, pointwise=False, progressbar=False):
159170
"""Calculate the widely available information criterion, its standard error
160171
and the effective number of parameters of the samples in trace from model.
161172
Read more theory here - in a paper by some of the leading authorities on
@@ -169,6 +180,10 @@ def waic(trace, model=None, pointwise=False):
169180
pointwise: bool
170181
if True the pointwise predictive accuracy will be returned.
171182
Default False
183+
progressbar: bool
184+
Whether or not to display a progress bar in the command line. The
185+
bar shows the percentage of completion, the evaluation speed, and
186+
the estimated time to completion
172187
173188
Returns
174189
-------
@@ -180,7 +195,7 @@ def waic(trace, model=None, pointwise=False):
180195
"""
181196
model = modelcontext(model)
182197

183-
log_py = _log_post_trace(trace, model)
198+
log_py = _log_post_trace(trace, model, progressbar=progressbar)
184199
if log_py.size == 0:
185200
raise ValueError('The model does not contain observed values.')
186201

@@ -208,7 +223,7 @@ def waic(trace, model=None, pointwise=False):
208223
return WAIC_r(waic, waic_se, p_waic)
209224

210225

211-
def loo(trace, model=None, pointwise=False):
226+
def loo(trace, model=None, pointwise=False, progressbar=False):
212227
"""Calculates leave-one-out (LOO) cross-validation for out of sample predictive
213228
model fit, following Vehtari et al. (2015). Cross-validation is computed using
214229
Pareto-smoothed importance sampling (PSIS).
@@ -221,6 +236,10 @@ def loo(trace, model=None, pointwise=False):
221236
pointwise: bool
222237
if True the pointwise predictive accuracy will be returned.
223238
Default False
239+
progressbar: bool
240+
Whether or not to display a progress bar in the command line. The
241+
bar shows the percentage of completion, the evaluation speed, and
242+
the estimated time to completion
224243
225244
Returns
226245
-------
@@ -232,7 +251,7 @@ def loo(trace, model=None, pointwise=False):
232251
"""
233252
model = modelcontext(model)
234253

235-
log_py = _log_post_trace(trace, model)
254+
log_py = _log_post_trace(trace, model, progressbar=progressbar)
236255
if log_py.size == 0:
237256
raise ValueError('The model does not contain observed values.')
238257

0 commit comments

Comments
 (0)