We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9ca63da commit 4876f87Copy full SHA for 4876f87
numpyro/util.py
@@ -6,6 +6,7 @@
6
import os
7
import random
8
import re
9
+import warnings
10
11
import numpy as np
12
import tqdm
@@ -300,6 +301,13 @@ def fori_collect(
300
301
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
302
start_idx = lower + (upper - lower) % thinning
303
num_chains = progbar_opts.pop("num_chains", 1)
304
+ # host_callback does not work yet with multi-GPU platforms
305
+ # See: https://github.com/google/jax/issues/6447
306
+ if num_chains > 1 and jax.default_backend() == "gpu":
307
+ warnings.warn(
308
+ "We will disable progress bar because it does not work yet on multi-GPUs platforms."
309
+ )
310
+ progbar = False
311
312
@cached_by(fori_collect, body_fun, transform)
313
def _body_fn(i, vals):
0 commit comments