We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d61f15c commit b19a83dCopy full SHA for b19a83d
numpyro/util.py
@@ -325,14 +325,6 @@ def fori_collect(
325
init_val_transformed = transform(init_val)
326
start_idx = lower + (upper - lower) % thinning
327
num_chains = progbar_opts.pop("num_chains", 1)
328
- # host_callback does not work yet with multi-GPU platforms
329
- # See: https://github.com/google/jax/issues/6447
330
- if num_chains > 1 and jax.default_backend() == "gpu":
331
- warnings.warn(
332
- "We will disable progress bar because it does not work yet on multi-GPUs platforms.",
333
- stacklevel=find_stack_level(),
334
- )
335
- progbar = False
336
337
@partial(maybe_jit, donate_argnums=2)
338
@cached_by(fori_collect, body_fun, transform)
0 commit comments