Description
Feature Summary
Recently, the group I work with has proposed a couple of samplers (https://arxiv.org/abs/2503.01707, https://arxiv.org/abs/2212.08549) as alternatives to NUTS HMC (see issue #1662). Since they appear to be quite a bit faster than NUTS (at least on the benchmark problems I've tried) and are relatively simple, I'm interested in adding them to NumPyro, but wanted to get some advice first.
Currently, implementations exist in Blackjax. In an ideal world, I'd make a new class like `class AdjustedMicrocanonical(numpyro.infer.mcmc.MCMCKernel)` which basically just wraps Blackjax.
In addition, my eventual goal would be to add not just the kernel, but also the tuning scheme (which is key to good performance). I'm curious if there's a straightforward way to do that.
Motivation
While it's easy to write a model in NumPyro, extract the density, and then use Blackjax for inference, we want to give users more direct access, mainly to make the samplers easier to discover.