diff --git a/docs/source/how_to/ConversionGuideNumPyro.ipynb b/docs/source/how_to/ConversionGuideNumPyro.ipynb new file mode 100644 index 0000000..1a39c47 --- /dev/null +++ b/docs/source/how_to/ConversionGuideNumPyro.ipynb @@ -0,0 +1,10237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a641832b", + "metadata": {}, + "source": [ + "(numpyro_conversion)=\n", + "# Converting NumPyro objects to DataTree\n", + "\n", + "{class}`~datatree.DataTree` is the data format ArviZ relies on.\n", + "\n", + "This page covers multiple ways to generate a `DataTree` from NumPyro MCMC and SVI objects." + ] + }, + { + "cell_type": "markdown", + "id": "279a434d", + "metadata": {}, + "source": [ + "```{seealso}\n", + "\n", + "- Conversion from Python, numpy or pandas objects\n", + "- {ref}`working_with_InferenceData` for an overview of `InferenceData` and its role within ArviZ.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "b702e7fd", + "metadata": {}, + "source": [ + "We will start by importing the required packages and defining the model. The famous 8 school model." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "87f7958f", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz_base as az\n", + "import numpy as np\n", + "\n", + "import numpyro\n", + "import numpyro.distributions as dist\n", + "from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive\n", + "from jax import random\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7e9a05da", + "metadata": {}, + "outputs": [], + "source": [ + "J = 8\n", + "y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", + "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "00357127", + "metadata": {}, + "outputs": [], + "source": [ + "def eight_schools_model(J, sigma, y=None):\n", + " mu = numpyro.sample(\"mu\", dist.Normal(0, 5))\n", + " tau = numpyro.sample(\"tau\", dist.HalfCauchy(5))\n", + " with numpyro.plate(\"J\", J):\n", + " eta = numpyro.sample(\"eta\", dist.Normal(0, 1))\n", + " theta = numpyro.deterministic(\"theta\", mu + tau * eta)\n", + " return numpyro.sample(\"obs\", dist.Normal(theta, sigma), obs=y)\n", + " \n", + "\n", + "def eight_schools_custom_guide(J, sigma, y=None):\n", + "\n", + " # Variational parameters for mu\n", + " mu_loc = numpyro.param(\"mu_loc\", 0.0)\n", + " mu_scale = numpyro.param(\"mu_scale\", 1.0, constraint=dist.constraints.positive)\n", + " mu = numpyro.sample(\"mu\", dist.Normal(mu_loc, mu_scale))\n", + "\n", + " # Variational parameters for tau (positive support)\n", + " tau_loc = numpyro.param(\"tau_loc\", 1.0)\n", + " tau_scale = numpyro.param(\"tau_scale\", 0.5, constraint=dist.constraints.positive)\n", + " tau = numpyro.sample(\"tau\", dist.LogNormal(jnp.log(tau_loc), tau_scale))\n", + "\n", + " # Variational parameters for eta\n", + " eta_loc = numpyro.param(\"eta_loc\", jnp.zeros(J))\n", + " eta_scale = numpyro.param(\"eta_scale\", jnp.ones(J), constraint=dist.constraints.positive)\n", + " with numpyro.plate(\"J\", J):\n", + " eta = numpyro.sample(\"eta\", dist.Normal(eta_loc, eta_scale))\n", + "\n", + " # Deterministic transform\n", + " numpyro.deterministic(\"theta\", mu + tau * eta)" + ] + }, + { + "cell_type": "markdown", + "id": "de2014a9", + "metadata": {}, + "source": [ + "## Convert from MCMC\n", + "\n", + "This first example shows conversion from MCMC" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "83ed9d9b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.\n", + " mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 3284.68it/s, 7 steps of size 3.56e-01. acc. prob=0.89] \n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 8009.51it/s, 7 steps of size 4.66e-01. acc. prob=0.88]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7861.79it/s, 7 steps of size 5.04e-01. acc. prob=0.84]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7739.55it/s, 7 steps of size 3.94e-01. acc. prob=0.90]\n" + ] + }, + { + "data": { + "text/html": [ + "
<xarray.DatasetView> Size: 0B\n", + "Dimensions: ()\n", + "Data variables:\n", + " *empty*
<xarray.DatasetView> Size: 0B\n", + "Dimensions: ()\n", + "Data variables:\n", + " *empty*
<xarray.DatasetView> Size: 0B\n", + "Dimensions: ()\n", + "Data variables:\n", + " *empty*
<xarray.DatasetView> Size: 0B\n", + "Dimensions: ()\n", + "Data variables:\n", + " *empty*