Skip to content

Commit b19a83d

Browse files
authored
enable progressbar for multi-gpu (#1849)
1 parent d61f15c commit b19a83d

File tree

1 file changed

+0
-8
lines changed

1 file changed

+0
-8
lines changed

numpyro/util.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,6 @@ def fori_collect(
325325
init_val_transformed = transform(init_val)
326326
start_idx = lower + (upper - lower) % thinning
327327
num_chains = progbar_opts.pop("num_chains", 1)
328-
# host_callback does not work yet with multi-GPU platforms
329-
# See: https://github.com/google/jax/issues/6447
330-
if num_chains > 1 and jax.default_backend() == "gpu":
331-
warnings.warn(
332-
"We will disable progress bar because it does not work yet on multi-GPUs platforms.",
333-
stacklevel=find_stack_level(),
334-
)
335-
progbar = False
336328

337329
@partial(maybe_jit, donate_argnums=2)
338330
@cached_by(fori_collect, body_fun, transform)

0 commit comments

Comments
 (0)