Skip to content

Commit ab18282

Browse files
authored
Fix chain detection in progress bar (#1077)
* fixup progress bars with multiple chains * add noteabout ipywidgets to docstring, add underscore to regex * make CHAIN_RE private
1 parent 86c65cf commit ab18282

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

numpyro/infer/mcmc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ class MCMC(object):
204204
.. note:: Setting `progress_bar=False` will improve the speed for many cases. But it might
205205
require more memory than the other option.
206206
207+
.. note:: If setting `num_chains` greater than `1` in a Jupyter Notebook, then you will need to
208+
have installed `ipywidgets <https://ipywidgets.readthedocs.io/en/latest/user_install.html>`_
209+
in the environment from which you launced Jupyter in order for the progress bars to render
210+
correctly.
211+
207212
:param MCMCKernel sampler: an instance of :class:`~numpyro.infer.mcmc.MCMCKernel` that
208213
determines the sampler for running MCMC. Currently, only :class:`~numpyro.infer.hmc.HMC`
209214
and :class:`~numpyro.infer.hmc.NUTS` are available.

numpyro/util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from jax.tree_util import tree_flatten, tree_map
2222

2323
_DISABLE_CONTROL_FLOW_PRIM = False
24+
_CHAIN_RE = re.compile(r"(?<=_)\d+$") # e.g. get '3' from 'TFRT_CPU_3'
2425

2526

2627
def set_rng_seed(rng_seed):
@@ -189,12 +190,16 @@ def progress_bar_factory(num_samples, num_chains):
189190
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)
190191

191192
def _update_tqdm(arg, transform, device):
192-
chain = int(str(device)[4:])
193+
chain_match = _CHAIN_RE.search(str(device))
194+
assert chain_match
195+
chain = int(chain_match.group())
193196
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
194197
tqdm_bars[chain].update(arg)
195198

196199
def _close_tqdm(arg, transform, device):
197-
chain = int(str(device)[4:])
200+
chain_match = _CHAIN_RE.search(str(device))
201+
assert chain_match
202+
chain = int(chain_match.group())
198203
tqdm_bars[chain].update(arg)
199204
finished_chains.append(chain)
200205
if len(finished_chains) == num_chains:

0 commit comments

Comments
 (0)