@@ -192,12 +192,22 @@ def __getstate__(self):
192192
193193
194194def _discrete_gibbs_proposal_body_fn (
195- z_init_flat , unravel_fn , pe_init , potential_fn , idx , i , val
195+ z_init_flat ,
196+ unravel_fn ,
197+ pe_init ,
198+ potential_fn ,
199+ idx ,
200+ i ,
201+ val ,
202+ support_size ,
203+ support_enumerate ,
196204):
197205 rng_key , z , pe , log_weight_sum = val
198206 rng_key , rng_transition = random .split (rng_key )
199- proposal = jnp .where (i >= z_init_flat [idx ], i + 1 , i )
200- z_new_flat = z_init_flat .at [idx ].set (proposal )
207+ proposal_index = jnp .where (
208+ support_enumerate [i ] == z_init_flat [idx ], support_size - 1 , i
209+ )
210+ z_new_flat = z_init_flat .at [idx ].set (support_enumerate [proposal_index ])
201211 z_new = unravel_fn (z_new_flat )
202212 pe_new = potential_fn (z_new )
203213 log_weight_new = pe_init - pe_new
@@ -216,7 +226,9 @@ def _discrete_gibbs_proposal_body_fn(
216226 return rng_key , z , pe , log_weight_sum
217227
218228
219- def _discrete_gibbs_proposal (rng_key , z_discrete , pe , potential_fn , idx , support_size ):
229+ def _discrete_gibbs_proposal (
230+ rng_key , z_discrete , pe , potential_fn , idx , support_size , support_enumerate
231+ ):
220232 # idx: current index of `z_discrete_flat` to update
221233 # support_size: support size of z_discrete at the index idx
222234
@@ -234,6 +246,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
234246 pe ,
235247 potential_fn ,
236248 idx ,
249+ support_size = support_size ,
250+ support_enumerate = support_enumerate ,
237251 )
238252 init_val = (rng_key , z_discrete , pe , jnp .array (0.0 ))
239253 rng_key , z_new , pe_new , _ = fori_loop (0 , support_size - 1 , body_fn , init_val )
@@ -242,7 +256,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
242256
243257
244258def _discrete_modified_gibbs_proposal (
245- rng_key , z_discrete , pe , potential_fn , idx , support_size , stay_prob = 0.0
259+ rng_key ,
260+ z_discrete ,
261+ pe ,
262+ potential_fn ,
263+ idx ,
264+ support_size ,
265+ support_enumerate ,
266+ stay_prob = 0.0 ,
246267):
247268 assert isinstance (stay_prob , float ) and stay_prob >= 0.0 and stay_prob < 1
248269 z_discrete_flat , unravel_fn = ravel_pytree (z_discrete )
@@ -253,6 +274,8 @@ def _discrete_modified_gibbs_proposal(
253274 pe ,
254275 potential_fn ,
255276 idx ,
277+ support_size = support_size ,
278+ support_enumerate = support_enumerate ,
256279 )
257280 # like gibbs_step but here, weight of the current value is 0
258281 init_val = (rng_key , z_discrete , pe , jnp .array (- jnp .inf ))
@@ -276,28 +299,41 @@ def _discrete_modified_gibbs_proposal(
276299 return rng_key , z_new , pe_new , log_accept_ratio
277300
278301
279- def _discrete_rw_proposal (rng_key , z_discrete , pe , potential_fn , idx , support_size ):
302+ def _discrete_rw_proposal (
303+ rng_key , z_discrete , pe , potential_fn , idx , support_size , support_enumerate
304+ ):
280305 rng_key , rng_proposal = random .split (rng_key , 2 )
281306 z_discrete_flat , unravel_fn = ravel_pytree (z_discrete )
282307
283308 proposal = random .randint (rng_proposal , (), minval = 0 , maxval = support_size )
284- z_new_flat = z_discrete_flat .at [idx ].set (proposal )
309+ z_new_flat = z_discrete_flat .at [idx ].set (support_enumerate [ proposal ] )
285310 z_new = unravel_fn (z_new_flat )
286311 pe_new = potential_fn (z_new )
287312 log_accept_ratio = pe - pe_new
288313 return rng_key , z_new , pe_new , log_accept_ratio
289314
290315
291316def _discrete_modified_rw_proposal (
292- rng_key , z_discrete , pe , potential_fn , idx , support_size , stay_prob = 0.0
317+ rng_key ,
318+ z_discrete ,
319+ pe ,
320+ potential_fn ,
321+ idx ,
322+ support_size ,
323+ support_enumerate ,
324+ stay_prob = 0.0 ,
293325):
294326 assert isinstance (stay_prob , float ) and stay_prob >= 0.0 and stay_prob < 1
295327 rng_key , rng_proposal , rng_stay = random .split (rng_key , 3 )
296328 z_discrete_flat , unravel_fn = ravel_pytree (z_discrete )
297329
298330 i = random .randint (rng_proposal , (), minval = 0 , maxval = support_size - 1 )
299- proposal = jnp .where (i >= z_discrete_flat [idx ], i + 1 , i )
300- proposal = jnp .where (random .bernoulli (rng_stay , stay_prob ), idx , proposal )
331+ proposal_index = jnp .where (
332+ support_enumerate [i ] == z_discrete_flat [idx ], support_size - 1 , i
333+ )
334+ proposal = jnp .where (
335+ random .bernoulli (rng_stay , stay_prob ), idx , support_enumerate [proposal_index ]
336+ )
301337 z_new_flat = z_discrete_flat .at [idx ].set (proposal )
302338 z_new = unravel_fn (z_new_flat )
303339 pe_new = potential_fn (z_new )
@@ -434,13 +470,32 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
434470 and site ["fn" ].has_enumerate_support
435471 and not site ["is_observed" ]
436472 }
437- self ._support_enumerates = {
438- name : site ["fn" ].enumerate_support (False )
439- for name , site in self ._prototype_trace .items ()
440- if site ["type" ] == "sample"
441- and site ["fn" ].has_enumerate_support
442- and not site ["is_observed" ]
443- }
473+ max_length_support_enumerates = max (
474+ (
475+ site ["fn" ].enumerate_support (False ).shape [0 ]
476+ for site in self ._prototype_trace .values ()
477+ if site ["type" ] == "sample"
478+ and site ["fn" ].has_enumerate_support
479+ and not site ["is_observed" ]
480+ )
481+ )
482+ # All support_enumerates should have the same length to be used in the loop
483+ # Each support is padded with zeros to have the same length
484+ self ._support_enumerates = np .zeros (
485+ (len (self ._support_sizes ), max_length_support_enumerates ), dtype = int
486+ )
487+ for i , (name , site ) in enumerate (self ._prototype_trace .items ()):
488+ if (
489+ site ["type" ] == "sample"
490+ and site ["fn" ].has_enumerate_support
491+ and not site ["is_observed" ]
492+ ):
493+ self ._support_enumerates [
494+ i , : site ["fn" ].enumerate_support (False ).shape [0 ]
495+ ] = site ["fn" ].enumerate_support (False )
496+ self ._support_enumerates = jnp .asarray (
497+ self ._support_enumerates , dtype = jnp .int32
498+ )
444499 self ._gibbs_sites = [
445500 name
446501 for name , site in self ._prototype_trace .items ()
0 commit comments