Skip to content

Update to fast sampling notebook #794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ examples/gallery.rst

pixi.lock


# pixi environments
.pixi
*.egg-info
396 changes: 263 additions & 133 deletions examples/samplers/fast_sampling_with_jax_and_numba.ipynb

Large diffs are not rendered by default.

136 changes: 116 additions & 20 deletions examples/samplers/fast_sampling_with_jax_and_numba.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ jupytext:
format_name: myst
format_version: 0.13
kernelspec:
display_name: pymc5recent
display_name: default
language: python
name: pymc5recent
name: python3
---

(faster_sampling_notebook)=
Expand All @@ -22,18 +22,61 @@ kernelspec:

+++

PyMC can compile its models to various execution backends through PyTensor, including:
* C
* JAX
* Numba
PyMC offers multiple sampling backends that can dramatically improve performance depending on your model size and requirements. Each backend has distinct advantages and is optimized for different use cases.

By default, PyMC is using the C backend which then gets called by the Python-based samplers.
### PyMC's Built-in Sampler

However, by compiling to other backends, we can use samplers written in other languages than Python that call the PyMC model without any Python-overhead.
```python
pm.sample()
```

The default PyMC sampler uses a Python-based NUTS implementation that provides maximum compatibility with all PyMC features. This sampler is always used when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis. While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using the `compile_kwargs` parameter, it still maintains Python overhead that can limit performance for large models.

### Nutpie Sampler

```python
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"})
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"})
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "pytensor"})
```

Nutpie is on the cutting-edge of PyMC sampling performance. Written in Rust, it eliminates most Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.

### NumPyro Sampler

```python
pm.sample(nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "parallel"})
# GPU-accelerated
pm.sample(nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "vectorized"})
```

NumPyro provides a mature JAX-based sampling implementation that integrates seamlessly with the broader JAX ecosystem. This sampler typically performs best with small to medium-sized models and offers excellent GPU support. NumPyro benefits from years of development within the JAX community and provides reliable performance characteristics, though it may have compilation overhead for very large models.

### BlackJAX Sampler

```python
pm.sample(nuts_sampler="blackjax")
```

For the JAX backend there is the NumPyro and BlackJAX NUTS sampler available. To use these samplers, you have to install `numpyro` and `blackjax`. Both of them are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`.
BlackJAX offers another JAX-based sampling implementation focused on flexibility and research applications. While it provides similar capabilities to NumPyro, it's less commonly used in production environments. BlackJAX can be valuable for experimental workflows or when specific JAX-based features are required that aren't available in other samplers.

For the Numba backend, there is the [Nutpie sampler](https://github.com/pymc-devs/nutpie) written in Rust. To use this sampler you need `nutpie` installed: `mamba install -c conda-forge nutpie`.
## Performance Guidelines

Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.

**Model Size Considerations**

For small models, NumPyro typically provides the best balance of performance and reliability. The compilation overhead is minimal, and the mature JAX implementation handles these models efficiently. Larger models often benefit from Nutpie with the Numba backend, which provides excellent performance without the memory overhead sometimes associated with JAX compilation.

Large models generally perform best with either Nutpie's JAX backend or Nutpie's Numba backend. The choice between these depends on whether GPU acceleration is needed and how the model's computational graph interacts with each backend's optimization strategies.

**Variable Type Requirements**

Models containing discrete variables have no choice but to use PyMC's built-in sampler, as it's the only implementation that supports the necessary Slice and Metropolis sampling algorithms. For purely continuous models, all sampling backends are available, making performance the primary consideration.

**Computational Backend Selection**

Numba excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. JAX offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The traditional C backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives.

```{code-cell} ipython3
import arviz as az
Expand All @@ -50,7 +93,7 @@ print(f"Running on PyMC v{pm.__version__}")
az.style.use("arviz-darkgrid")
```

We will use a simple probabilistic PCA model as our example.
We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.

```{code-cell} ipython3
def build_toy_dataset(N, D, K, sigma=1):
Expand Down Expand Up @@ -91,44 +134,97 @@ with pm.Model() as PPCA:
x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
```

## Sampling using Python NUTS sampler
## Performance Comparison

Now let's compare the performance of different sampling backends on our PPCA model. We'll measure both compilation time and sampling speed.

### 1. PyMC Default Sampler (Python NUTS)

```{code-cell} ipython3
%%time
with PPCA:
idata_pymc = pm.sample()
```

## Sampling using NumPyro JAX NUTS sampler
### 2. Nutpie with Numba Backend

```{code-cell} ipython3
%%time
with PPCA:
idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
idata_nutpie_numba = pm.sample(
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"}, progressbar=False
)
```

## Sampling using BlackJAX NUTS sampler
### 3. Nutpie with JAX Backend

```{code-cell} ipython3
%%time
with PPCA:
idata_blackjax = pm.sample(nuts_sampler="blackjax")
idata_nutpie_jax = pm.sample(
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"}, progressbar=False
)
```

## Sampling using Nutpie Rust NUTS sampler
### 4. NumPyro Sampler

```{code-cell} ipython3
%%time
with PPCA:
idata_nutpie = pm.sample(nuts_sampler="nutpie")
idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
```

## Installation Requirements

To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.

+++

## Special Cases and Advanced Usage

### Using PyMC's Built-in Sampler with Different Backends

In certain scenarios, you may need to use PyMC's Python-based sampler while still benefiting from faster computational backends. This situation commonly arises when working with models that contain discrete variables, which require PyMC's specialized sampling algorithms. Even in these cases, you can significantly improve performance by compiling the model's computational graph to more efficient backends.

The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.

```{code-cell} ipython3
with PPCA:
idata_c = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "fast_run"})

# with PPCA:
# idata_pymc_numba = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "numba"})

# with PPCA:
# idata_pymc_jax = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "jax"})
```

The above examples are commented out to avoid redundant sampling in this demonstration notebook. In practice, you would uncomment and run the configuration that matches your model's requirements. These compilation modes allow you to access faster computational backends even when you must use PyMC's Python-based sampler for compatibility reasons.

+++

### Models with Discrete Variables

When working with models that contain discrete variables, you have no choice but to use PyMC's built-in sampler. This is because discrete variables require specialized sampling algorithms like Slice sampling or Metropolis-Hastings that are only available in PyMC's Python implementation. The example below demonstrates a typical scenario where this constraint applies.

```{code-cell} ipython3
with pm.Model() as discrete_model:
cluster = pm.Categorical("cluster", p=[0.3, 0.7], shape=100)
mu = pm.Normal("mu", 0, 1, shape=2)
sigma = pm.HalfNormal("sigma", 1, shape=2)
obs = pm.Normal("obs", mu=mu[cluster], sigma=sigma[cluster], observed=rng.normal(0, 1, 100))

trace_discrete = pm.sample()
```

## Authors
Authored by Thomas Wiecki in July 2023

- Originally authored by Thomas Wiecki in July 2023
- Substantially updated and expanded by Chris Fonnesbeck in May 2025

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
```

:::{include} ../page_footer.md
Expand Down
34 changes: 10 additions & 24 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -1,35 +1,21 @@
[project]
[workspace]
authors = ["Chris Fonnesbeck <[email protected]>"]
channels = ["conda-forge"]
description = "Add a short description here"
name = "pymc-examples"
platforms = ["linux-64"]
version = "0.1.0"

[tasks]

[dependencies]
python = ">=3.12.5,<4"
pymc = ">=5.16.2,<6"
jupyter = ">=1.1.1,<2"
pymc = ">=5.22.0,<6"
nutpie = ">=0.14.3,<0.15"
numpyro = ">=0.18.0,<0.19"
numba = ">=0.61.2,<0.62"
ipywidgets = ">=8.1.7,<9"
arviz = ">=0.21.0,<0.22"
matplotlib = ">=3.10.3,<4"
python = ">=3.12.10,<3.13"
ipykernel = ">=6.29.5,<7"
ipywidgets = ">=8.1.5,<9"
numpy = ">=1.26.4,<2"
arviz = ">=0.19.0,<0.20"
numpyro = ">=0.15.2,<0.16"
seaborn = ">=0.13.2,<0.14"
matplotlib = ">=3.9.2,<4"
pandas = ">=2.2.2,<3"
polars = ">=1.6.0,<2"
esbonio = ">=0.16.4,<0.17"
blackjax = ">=1.2.4,<2"
watermark = ">=2.5.0,<3"
nutpie = ">=0.13.2,<0.14"
numba = ">=0.60.0,<0.61"
scikit-learn = ">=1.5.2,<2"
blackjax = ">=1.2.3,<2"
networkx = ">=3.4.2,<4"
bokeh = ">=3.7.2,<4"

[pypi-dependencies]
pymc-experimental = ">=0.1.2, <0.2"
pymc-extras = ">=0.2.0, <0.3"