Skip to content

Commit 7860e36

Browse files
tarefehiepsitillahoffmannJuan Orduz
authored
Fix padding calculation (#1993)
* Fix pad calculation * Update paths to data files (#1996) * Update paths to data files * address numerical issues due to new jax release * fix the size of the new SP500 file * Add `sum_sites` option to sum loss over sites or return as dict. (#1995) * Add `sum_sites` option to sum loss over sites or return as `dict`. * Add missing newline in warning. * Improve rng default in the `contrib/model.py` module (#1992) * fix * rm comment * alternative rng * suggestion --------- Co-authored-by: Du Phan <[email protected]> Co-authored-by: Till Hoffmann <[email protected]> Co-authored-by: Juan Orduz <[email protected]>
1 parent f5454aa commit 7860e36

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

numpyro/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def soft_vmap(
438438
# it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
439439
chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
440440
if chunk_size > 1:
441-
pad = chunk_size - (batch_size % chunk_size)
441+
pad = chunk_size - batch_size % chunk_size if batch_size % chunk_size else 0
442442
xs = jax.tree.map(
443443
lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs
444444
)

0 commit comments

Comments
 (0)