|
| 1 | +from collections import defaultdict |
| 2 | + |
| 3 | +from joblib import Parallel, delayed |
| 4 | +from numpy.random import randint, seed |
| 5 | +from numpy import shape, asarray |
| 6 | + |
1 | 7 | from . import backends
|
2 | 8 | from .backends.base import merge_traces, BaseTrace, MultiTrace
|
3 | 9 | from .backends.ndarray import NDArray
|
4 |
| -from joblib import Parallel, delayed |
5 | 10 | from .model import modelcontext, Point
|
6 | 11 | from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
|
7 | 12 | BinaryGibbsMetropolis, Slice, ElemwiseCategorical, CompoundStep)
|
8 |
| -from .progressbar import progress_bar |
9 |
| -from numpy.random import randint, seed |
10 |
| -from numpy import shape, asarray |
11 |
| -from collections import defaultdict |
| 13 | +from tqdm import tqdm |
12 | 14 |
|
13 | 15 | import sys
|
14 | 16 | sys.setrecursionlimit(10000)
|
@@ -159,11 +161,11 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
|
159 | 161 | progressbar=True, model=None, random_seed=-1):
|
160 | 162 | sampling = _iter_sample(draws, step, start, trace, chain,
|
161 | 163 | tune, model, random_seed)
|
162 |
| - progress = progress_bar(draws) |
| 164 | + if progressbar: |
| 165 | + sampling = tqdm(sampling, total=draws) |
163 | 166 | try:
|
164 |
| - for i, strace in enumerate(sampling): |
165 |
| - if progressbar: |
166 |
| - progress.update(i) |
| 167 | + for strace in sampling: |
| 168 | + pass |
167 | 169 | except KeyboardInterrupt:
|
168 | 170 | strace.close()
|
169 | 171 | return MultiTrace([strace])
|
|
0 commit comments