Skip to content

Commit 0cd448f

Browse files
committed
updated with enumerate support as padded zeros arrays
1 parent 59de90f commit 0cd448f

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,22 @@ def __getstate__(self):
192192

193193

194194
def _discrete_gibbs_proposal_body_fn(
195-
z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
195+
z_init_flat,
196+
unravel_fn,
197+
pe_init,
198+
potential_fn,
199+
idx,
200+
i,
201+
val,
202+
support_size,
203+
support_enumerate,
196204
):
197205
rng_key, z, pe, log_weight_sum = val
198206
rng_key, rng_transition = random.split(rng_key)
199-
proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
200-
z_new_flat = z_init_flat.at[idx].set(proposal)
207+
proposal_index = jnp.where(
208+
support_enumerate[i] == z_init_flat[idx], support_size - 1, i
209+
)
210+
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
201211
z_new = unravel_fn(z_new_flat)
202212
pe_new = potential_fn(z_new)
203213
log_weight_new = pe_init - pe_new
@@ -216,7 +226,9 @@ def _discrete_gibbs_proposal_body_fn(
216226
return rng_key, z, pe, log_weight_sum
217227

218228

219-
def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
229+
def _discrete_gibbs_proposal(
230+
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
231+
):
220232
# idx: current index of `z_discrete_flat` to update
221233
# support_size: support size of z_discrete at the index idx
222234

@@ -234,6 +246,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
234246
pe,
235247
potential_fn,
236248
idx,
249+
support_size=support_size,
250+
support_enumerate=support_enumerate,
237251
)
238252
init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
239253
rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
@@ -242,7 +256,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
242256

243257

244258
def _discrete_modified_gibbs_proposal(
245-
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
259+
rng_key,
260+
z_discrete,
261+
pe,
262+
potential_fn,
263+
idx,
264+
support_size,
265+
support_enumerate,
266+
stay_prob=0.0,
246267
):
247268
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
248269
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
@@ -253,6 +274,8 @@ def _discrete_modified_gibbs_proposal(
253274
pe,
254275
potential_fn,
255276
idx,
277+
support_size=support_size,
278+
support_enumerate=support_enumerate,
256279
)
257280
# like gibbs_step but here, weight of the current value is 0
258281
init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
@@ -276,28 +299,41 @@ def _discrete_modified_gibbs_proposal(
276299
return rng_key, z_new, pe_new, log_accept_ratio
277300

278301

279-
def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
302+
def _discrete_rw_proposal(
303+
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
304+
):
280305
rng_key, rng_proposal = random.split(rng_key, 2)
281306
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
282307

283308
proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
284-
z_new_flat = z_discrete_flat.at[idx].set(proposal)
309+
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
285310
z_new = unravel_fn(z_new_flat)
286311
pe_new = potential_fn(z_new)
287312
log_accept_ratio = pe - pe_new
288313
return rng_key, z_new, pe_new, log_accept_ratio
289314

290315

291316
def _discrete_modified_rw_proposal(
292-
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
317+
rng_key,
318+
z_discrete,
319+
pe,
320+
potential_fn,
321+
idx,
322+
support_size,
323+
support_enumerate,
324+
stay_prob=0.0,
293325
):
294326
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
295327
rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
296328
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
297329

298330
i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
299-
proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
300-
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
331+
proposal_index = jnp.where(
332+
support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
333+
)
334+
proposal = jnp.where(
335+
random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
336+
)
301337
z_new_flat = z_discrete_flat.at[idx].set(proposal)
302338
z_new = unravel_fn(z_new_flat)
303339
pe_new = potential_fn(z_new)
@@ -434,13 +470,32 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
434470
and site["fn"].has_enumerate_support
435471
and not site["is_observed"]
436472
}
437-
self._support_enumerates = {
438-
name: site["fn"].enumerate_support(False)
439-
for name, site in self._prototype_trace.items()
440-
if site["type"] == "sample"
441-
and site["fn"].has_enumerate_support
442-
and not site["is_observed"]
443-
}
473+
max_length_support_enumerates = max(
474+
(
475+
site["fn"].enumerate_support(False).shape[0]
476+
for site in self._prototype_trace.values()
477+
if site["type"] == "sample"
478+
and site["fn"].has_enumerate_support
479+
and not site["is_observed"]
480+
)
481+
)
482+
# All support_enumerates should have the same length to be used in the loop
483+
# Each support is padded with zeros to have the same length
484+
self._support_enumerates = np.zeros(
485+
(len(self._support_sizes), max_length_support_enumerates), dtype=int
486+
)
487+
for i, (name, site) in enumerate(self._prototype_trace.items()):
488+
if (
489+
site["type"] == "sample"
490+
and site["fn"].has_enumerate_support
491+
and not site["is_observed"]
492+
):
493+
self._support_enumerates[
494+
i, : site["fn"].enumerate_support(False).shape[0]
495+
] = site["fn"].enumerate_support(False)
496+
self._support_enumerates = jnp.asarray(
497+
self._support_enumerates, dtype=jnp.int32
498+
)
444499
self._gibbs_sites = [
445500
name
446501
for name, site in self._prototype_trace.items()

numpyro/infer/mixed_hmc.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def update_discrete(
139139
partial(potential_fn, z_hmc=hmc_state.z),
140140
idx,
141141
self._support_sizes_flat[idx],
142+
self._support_enumerates[idx],
142143
)
143144
# Algo 1, line 20: depending on reject or refract, we will update
144145
# the discrete variable and its corresponding kinetic energy. In case of
@@ -302,11 +303,6 @@ def body_fn(i, vals):
302303
adapt_state=adapt_state,
303304
)
304305

305-
z_discrete = jax.tree.map(
306-
lambda idx, support: support[idx],
307-
z_discrete,
308-
self._support_enumerates,
309-
)
310306
z = {**z_discrete, **hmc_state.z}
311307
return MixedHMCState(z, hmc_state, rng_key, accept_prob)
312308

0 commit comments

Comments
 (0)