
Adding new samplers to NumPyro #2035

@reubenharry


Feature Summary

Recently, the group I work with has proposed a couple of samplers (https://arxiv.org/abs/2503.01707, https://arxiv.org/abs/2212.08549) as alternatives to NUTS HMC (see issue #1662). Since they appear to be considerably faster than NUTS (at least on the benchmark problems I've tried) and are relatively simple, I'm interested in adding them to NumPyro, but wanted to get some advice first.

Implementations currently exist in Blackjax. Ideally, I'd add a new class such as `class AdjustedMicrocanonical(numpyro.infer.mcmc.MCMCKernel)` that simply wraps Blackjax.

Beyond the kernel itself, my eventual goal is to add the tuning scheme as well, since it is key to good performance. Is there a straightforward way to do that?

Motivation

While it's easy to write a model in NumPyro, extract its density, and then use Blackjax for inference, we want to give users more direct access to these samplers, mainly to make them more discoverable.
