File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -204,6 +204,11 @@ class MCMC(object):
204204 .. note:: Setting `progress_bar=False` will improve the speed for many cases. But it might
205205 require more memory than the other option.
206206
207+ .. note:: If setting `num_chains` greater than `1` in a Jupyter Notebook, then you will need to
208+ have installed `ipywidgets <https://ipywidgets.readthedocs.io/en/latest/user_install.html>`_
209+ in the environment from which you launced Jupyter in order for the progress bars to render
210+ correctly.
211+
207212 :param MCMCKernel sampler: an instance of :class:`~numpyro.infer.mcmc.MCMCKernel` that
208213 determines the sampler for running MCMC. Currently, only :class:`~numpyro.infer.hmc.HMC`
209214 and :class:`~numpyro.infer.hmc.NUTS` are available.
Original file line number Diff line number Diff line change 2121from jax .tree_util import tree_flatten , tree_map
2222
2323_DISABLE_CONTROL_FLOW_PRIM = False
24+ _CHAIN_RE = re .compile (r"(?<=_)\d+$" ) # e.g. get '3' from 'TFRT_CPU_3'
2425
2526
2627def set_rng_seed (rng_seed ):
@@ -189,12 +190,16 @@ def progress_bar_factory(num_samples, num_chains):
189190 tqdm_bars [chain ].set_description ("Compiling.. " , refresh = True )
190191
191192 def _update_tqdm (arg , transform , device ):
192- chain = int (str (device )[4 :])
193+ chain_match = _CHAIN_RE .search (str (device ))
194+ assert chain_match
195+ chain = int (chain_match .group ())
193196 tqdm_bars [chain ].set_description (f"Running chain { chain } " , refresh = False )
194197 tqdm_bars [chain ].update (arg )
195198
196199 def _close_tqdm (arg , transform , device ):
197- chain = int (str (device )[4 :])
200+ chain_match = _CHAIN_RE .search (str (device ))
201+ assert chain_match
202+ chain = int (chain_match .group ())
198203 tqdm_bars [chain ].update (arg )
199204 finished_chains .append (chain )
200205 if len (finished_chains ) == num_chains :
You can’t perform that action at this time.
0 commit comments