Skip to content

Commit e14eea7

Browse files
committed
updating the logical using ravel to maintain a consistant behaviour
1 parent 0cd448f commit e14eea7

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
from functools import partial
77

8+
import jax
89
import numpy as np
910

1011
from jax import device_put, grad, jacfwd, random, value_and_grad
@@ -470,6 +471,11 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
470471
and site["fn"].has_enumerate_support
471472
and not site["is_observed"]
472473
}
474+
475+
# All support_enumerates should have the same length to be used in the loop
476+
# Each support is padded with zeros to have the same length
477+
# ravel is used to maintain a consistant behaviour with `support_sizes`
478+
473479
max_length_support_enumerates = max(
474480
(
475481
site["fn"].enumerate_support(False).shape[0]
@@ -479,23 +485,25 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
479485
and not site["is_observed"]
480486
)
481487
)
482-
# All support_enumerates should have the same length to be used in the loop
483-
# Each support is padded with zeros to have the same length
484-
self._support_enumerates = np.zeros(
485-
(len(self._support_sizes), max_length_support_enumerates), dtype=int
486-
)
487-
for i, (name, site) in enumerate(self._prototype_trace.items()):
488-
if (
489-
site["type"] == "sample"
490-
and site["fn"].has_enumerate_support
491-
and not site["is_observed"]
492-
):
493-
self._support_enumerates[
494-
i, : site["fn"].enumerate_support(False).shape[0]
495-
] = site["fn"].enumerate_support(False)
496-
self._support_enumerates = jnp.asarray(
497-
self._support_enumerates, dtype=jnp.int32
498-
)
488+
489+
support_enumerates = {}
490+
for name, support_size in self._support_sizes.items():
491+
site = self._prototype_trace[name]
492+
enumerate_support = site["fn"].enumerate_support(False)
493+
padded_enumerate_support = np.pad(
494+
enumerate_support,
495+
(0, max_length_support_enumerates - enumerate_support.shape[0]),
496+
)
497+
padded_enumerate_support = np.broadcast_to(
498+
padded_enumerate_support,
499+
support_size.shape + (max_length_support_enumerates,),
500+
)
501+
support_enumerates[name] = padded_enumerate_support
502+
503+
self._support_enumerates = jax.vmap(
504+
lambda x: ravel_pytree(x)[0] , in_axes=0, out_axes=1
505+
)(support_enumerates)
506+
499507
self._gibbs_sites = [
500508
name
501509
for name, site in self._prototype_trace.items()

0 commit comments

Comments
 (0)