55import copy
66from functools import partial
77
8+ import jax
89import numpy as np
910
1011from jax import device_put , grad , jacfwd , random , value_and_grad
@@ -470,6 +471,11 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
470471 and site ["fn" ].has_enumerate_support
471472 and not site ["is_observed" ]
472473 }
474+
475+ # All support_enumerates should have the same length to be used in the loop
476+ # Each support is padded with zeros to have the same length
477+ # ravel is used to maintain a consistant behaviour with `support_sizes`
478+
473479 max_length_support_enumerates = max (
474480 (
475481 site ["fn" ].enumerate_support (False ).shape [0 ]
@@ -479,23 +485,25 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
479485 and not site ["is_observed" ]
480486 )
481487 )
482- # All support_enumerates should have the same length to be used in the loop
483- # Each support is padded with zeros to have the same length
484- self ._support_enumerates = np .zeros (
485- (len (self ._support_sizes ), max_length_support_enumerates ), dtype = int
486- )
487- for i , (name , site ) in enumerate (self ._prototype_trace .items ()):
488- if (
489- site ["type" ] == "sample"
490- and site ["fn" ].has_enumerate_support
491- and not site ["is_observed" ]
492- ):
493- self ._support_enumerates [
494- i , : site ["fn" ].enumerate_support (False ).shape [0 ]
495- ] = site ["fn" ].enumerate_support (False )
496- self ._support_enumerates = jnp .asarray (
497- self ._support_enumerates , dtype = jnp .int32
498- )
488+
489+ support_enumerates = {}
490+ for name , support_size in self ._support_sizes .items ():
491+ site = self ._prototype_trace [name ]
492+ enumerate_support = site ["fn" ].enumerate_support (False )
493+ padded_enumerate_support = np .pad (
494+ enumerate_support ,
495+ (0 , max_length_support_enumerates - enumerate_support .shape [0 ]),
496+ )
497+ padded_enumerate_support = np .broadcast_to (
498+ padded_enumerate_support ,
499+ support_size .shape + (max_length_support_enumerates ,),
500+ )
501+ support_enumerates [name ] = padded_enumerate_support
502+
503+ self ._support_enumerates = jax .vmap (
504+ lambda x : ravel_pytree (x )[0 ] , in_axes = 0 , out_axes = 1
505+ )(support_enumerates )
506+
499507 self ._gibbs_sites = [
500508 name
501509 for name , site in self ._prototype_trace .items ()
0 commit comments