Skip to content

Commit 4876f87

Browse files
authored
Disable progbar for multi-GPU platform (#1048)
* disable progbar for multi-GPU platform * add a warning for progress bar on GPU * add missing import
1 parent 9ca63da commit 4876f87

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

numpyro/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import random
88
import re
9+
import warnings
910

1011
import numpy as np
1112
import tqdm
@@ -300,6 +301,13 @@ def fori_collect(
300301
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
301302
start_idx = lower + (upper - lower) % thinning
302303
num_chains = progbar_opts.pop("num_chains", 1)
304+
# host_callback does not work yet with multi-GPU platforms
305+
# See: https://github.com/google/jax/issues/6447
306+
if num_chains > 1 and jax.default_backend() == "gpu":
307+
warnings.warn(
308+
"We will disable progress bar because it does not work yet on multi-GPUs platforms."
309+
)
310+
progbar = False
303311

304312
@cached_by(fori_collect, body_fun, transform)
305313
def _body_fn(i, vals):

0 commit comments

Comments
 (0)