diff --git a/docs/source/how_to/ConversionGuideNumPyro.ipynb b/docs/source/how_to/ConversionGuideNumPyro.ipynb new file mode 100644 index 0000000..1a39c47 --- /dev/null +++ b/docs/source/how_to/ConversionGuideNumPyro.ipynb @@ -0,0 +1,10237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a641832b", + "metadata": {}, + "source": [ + "(numpyro_conversion)=\n", + "# Converting NumPyro objects to DataTree\n", + "\n", + "{class}`~datatree.DataTree` is the data format ArviZ relies on.\n", + "\n", + "This page covers multiple ways to generate a `DataTree` from NumPyro MCMC and SVI objects." + ] + }, + { + "cell_type": "markdown", + "id": "279a434d", + "metadata": {}, + "source": [ + "```{seealso}\n", + "\n", + "- Conversion from Python, numpy or pandas objects\n", + "- {ref}`working_with_InferenceData` for an overview of `InferenceData` and its role within ArviZ.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "b702e7fd", + "metadata": {}, + "source": [ + "We will start by importing the required packages and defining the model: the well-known eight schools model."
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "87f7958f", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz_base as az\n", + "import numpy as np\n", + "\n", + "import numpyro\n", + "import numpyro.distributions as dist\n", + "from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive\n", + "from jax import random\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7e9a05da", + "metadata": {}, + "outputs": [], + "source": [ + "J = 8\n", + "y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", + "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "00357127", + "metadata": {}, + "outputs": [], + "source": [ + "def eight_schools_model(J, sigma, y=None):\n", + " mu = numpyro.sample(\"mu\", dist.Normal(0, 5))\n", + " tau = numpyro.sample(\"tau\", dist.HalfCauchy(5))\n", + " with numpyro.plate(\"J\", J):\n", + " eta = numpyro.sample(\"eta\", dist.Normal(0, 1))\n", + " theta = numpyro.deterministic(\"theta\", mu + tau * eta)\n", + " return numpyro.sample(\"obs\", dist.Normal(theta, sigma), obs=y)\n", + " \n", + "\n", + "def eight_schools_custom_guide(J, sigma, y=None):\n", + "\n", + " # Variational parameters for mu\n", + " mu_loc = numpyro.param(\"mu_loc\", 0.0)\n", + " mu_scale = numpyro.param(\"mu_scale\", 1.0, constraint=dist.constraints.positive)\n", + " mu = numpyro.sample(\"mu\", dist.Normal(mu_loc, mu_scale))\n", + "\n", + " # Variational parameters for tau (positive support)\n", + " tau_loc = numpyro.param(\"tau_loc\", 1.0)\n", + " tau_scale = numpyro.param(\"tau_scale\", 0.5, constraint=dist.constraints.positive)\n", + " tau = numpyro.sample(\"tau\", dist.LogNormal(jnp.log(tau_loc), tau_scale))\n", + "\n", + " # Variational parameters for eta\n", + " eta_loc = numpyro.param(\"eta_loc\", jnp.zeros(J))\n", + " eta_scale = numpyro.param(\"eta_scale\", jnp.ones(J), 
constraint=dist.constraints.positive)\n", + " with numpyro.plate(\"J\", J):\n", + " eta = numpyro.sample(\"eta\", dist.Normal(eta_loc, eta_scale))\n", + "\n", + " # Deterministic transform\n", + " numpyro.deterministic(\"theta\", mu + tau * eta)" + ] + }, + { + "cell_type": "markdown", + "id": "de2014a9", + "metadata": {}, + "source": [ + "## Convert from MCMC\n", + "\n", + "This first example shows conversion from MCMC" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "83ed9d9b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.\n", + " mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 3284.68it/s, 7 steps of size 3.56e-01. acc. prob=0.89] \n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 8009.51it/s, 7 steps of size 4.66e-01. acc. prob=0.88]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7861.79it/s, 7 steps of size 5.04e-01. acc. prob=0.84]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7739.55it/s, 7 steps of size 3.94e-01. acc. prob=0.90]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DatasetView> Size: 0B\n",
+       "Dimensions:  ()\n",
+       "Data variables:\n",
+       "    *empty*
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior\n", + "│ Dimensions: (chain: 4, draw: 1000, J: 8)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ eta (chain, draw, J) float32 128kB 1.691 1.086 ... -0.05226 -0.7434\n", + "│ mu (chain, draw) float32 16kB -3.026 0.3796 8.286 ... 5.022 3.53 1.153\n", + "│ tau (chain, draw) float32 16kB 10.34 4.316 0.6902 ... 1.996 1.01 12.41\n", + "│ theta (chain, draw, J) float32 128kB 14.45 8.198 -6.399 ... 0.5049 -8.071\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:26.956151+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /sample_stats\n", + "│ Dimensions: (chain: 4, draw: 1000)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", + "│ Data variables:\n", + "│ diverging (chain, draw) bool 4kB ...\n", + "│ energy (chain, draw) float32 16kB ...\n", + "│ n_steps (chain, draw) int32 16kB ...\n", + "│ tree_depth (chain, draw) int64 32kB 4 4 4 3 4 3 3 3 3 ... 3 4 4 4 3 4 3 3 3\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:26.957867+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (chain: 4, draw: 1000, J: 8)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 
993 994 995 996 997 998 999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ obs (chain, draw, J) float32 128kB 28.67 -1.249 ... 9.282 0.07767\n", + "│ theta (chain, draw, J) float32 128kB 14.45 8.198 -6.399 ... 0.5049 -8.071\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:26.958921+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "└── Group: /observed_data\n", + " Dimensions: (J: 8)\n", + " Coordinates:\n", + " * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + " Data variables:\n", + " obs (J) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", + " Attributes:\n", + " created_at: 2025-10-29T16:01:26.959668+00:00\n", + " creation_library: ArviZ\n", + " creation_library_version: 0.7.0.dev0\n", + " creation_library_language: Python\n", + " inference_library: numpyro\n", + " inference_library_version: 0.19.0" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# fit with MCMC\n", + "nuts = NUTS(eight_schools_model)\n", + "mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)\n", + "mcmc.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=(\"num_steps\", \"energy\"),)\n", + "\n", + "# sample the posterior predictive\n", + "predictive = Predictive(eight_schools_model, mcmc.get_samples())\n", + "samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)\n", + "\n", + "# Convert the MCMC object to a DataTree\n", + "idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)\n", + "idata_mcmc" + ] + }, + { + "cell_type": "markdown", + "id": "d4fd92fa", + "metadata": {}, + "source": [ + "## Convert from SVI with Autoguide" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a9ef6964", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", +
"text": [ + "100%|██████████| 10000/10000 [00:00<00:00, 10095.92it/s, init loss: 53.6608, avg. loss [9501-10000]: 31.6204]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DatasetView> Size: 0B\n",
+       "Dimensions:  ()\n",
+       "Data variables:\n",
+       "    *empty*
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior\n", + "│ Dimensions: (sample: 4000, J: 8)\n", + "│ Coordinates:\n", + "│ * sample (sample) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ eta (sample, J) float32 128kB -1.212 0.8609 1.496 ... -1.611 -0.6918\n", + "│ mu (sample) float32 16kB -3.658 -2.307 5.14 ... 2.604 0.9337 3.985\n", + "│ tau (sample) float32 16kB 0.9026 1.751 1.003 ... 0.8077 3.972 2.689\n", + "│ theta (sample, J) float32 128kB -4.752 -2.881 -2.308 ... -0.3478 2.125\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:28.480833+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /sample_stats\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:28.481750+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (sample: 4000, J: 8)\n", + "│ Coordinates:\n", + "│ * sample (sample) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ obs (sample, J) float32 128kB 4.552 -3.569 -26.3 ... -1.14 -0.9699\n", + "│ theta (sample, J) float32 128kB 2.909 2.912 1.596 ... 
-3.738 -2.447\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:28.482498+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "└── Group: /observed_data\n", + " Dimensions: (J: 8)\n", + " Coordinates:\n", + " * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + " Data variables:\n", + " obs (J) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", + " Attributes:\n", + " created_at: 2025-10-29T16:01:28.483124+00:00\n", + " creation_library: ArviZ\n", + " creation_library_version: 0.7.0.dev0\n", + " creation_library_language: Python\n", + " inference_library: numpyro\n", + " inference_library_version: 0.19.0" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eight_schools_guide = autoguide.AutoNormal(eight_schools_model, init_loc_fn=numpyro.infer.init_to_median(num_samples=100))\n", + "svi = SVI(\n", + " eight_schools_model, \n", + " guide=eight_schools_guide,\n", + " optim=numpyro.optim.Adam(0.01),\n", + " loss = Trace_ELBO()\n", + ")\n", + "svi_result = svi.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)\n", + "\n", + "# sample the posterior predictive\n", + "predictive_svi = Predictive(eight_schools_model, guide=eight_schools_guide, params=svi_result.params, num_samples=4000)\n", + "samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)\n", + "\n", + "\n", + "idata_svi = az.from_numpyro_svi(\n", + " svi,\n", + " svi_result=svi_result,\n", + " model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs\n", + " num_samples = 4000, # number of samples to draw in the posterior\n", + " posterior_predictive=samples_predictive_svi\n", + ")\n", + "idata_svi" + ] + }, + { + "cell_type": "markdown", + "id": "bd72155b", + "metadata": {}, + "source": [ + "## Converting from SVI with a 
custom guide function" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3e4fffc6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10000/10000 [00:00<00:00, 10246.71it/s, init loss: 34.9525, avg. loss [9501-10000]: 31.6279]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DatasetView> Size: 0B\n",
+       "Dimensions:  ()\n",
+       "Data variables:\n",
+       "    *empty*
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior\n", + "│ Dimensions: (sample: 4000, J: 8)\n", + "│ Coordinates:\n", + "│ * sample (sample) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ eta (sample, J) float32 128kB -2.211 0.1604 1.648 ... -2.492 0.1236\n", + "│ mu (sample) float32 16kB -0.3081 -0.6043 6.962 ... 2.391 6.263 7.293\n", + "│ tau (sample) float32 16kB 0.7794 3.565 4.964 ... 1.445 5.136 2.295\n", + "│ theta (sample, J) float32 128kB -2.032 -0.1831 0.9759 ... 1.575 7.576\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:30.153743+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /sample_stats\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:30.154659+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (sample: 4000, J: 8)\n", + "│ Coordinates:\n", + "│ * sample (sample) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ obs (sample, J) float32 128kB 0.8301 -7.256 -29.28 ... -0.2221 -0.507\n", + "│ theta (sample, J) float32 128kB -0.813 -0.7758 -1.391 ... 
-2.82 -1.984\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:30.155301+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "└── Group: /observed_data\n", + " Dimensions: (J: 8)\n", + " Coordinates:\n", + " * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + " Data variables:\n", + " obs (J) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", + " Attributes:\n", + " created_at: 2025-10-29T16:01:30.155909+00:00\n", + " creation_library: ArviZ\n", + " creation_library_version: 0.7.0.dev0\n", + " creation_library_language: Python\n", + " inference_library: numpyro\n", + " inference_library_version: 0.19.0" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "svi_custom_guide = SVI(\n", + " eight_schools_model, \n", + " guide=eight_schools_custom_guide,\n", + " optim=numpyro.optim.Adam(0.01),\n", + " loss = Trace_ELBO()\n", + ")\n", + "svi_custom_guide_result = svi_custom_guide.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)\n", + "\n", + "# sample the posterior predictive\n", + "predictive_svi_custom = Predictive(eight_schools_model, guide=eight_schools_custom_guide, params=svi_custom_guide_result.params, num_samples=4000)\n", + "samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)\n", + "\n", + "idata_svi_custom_guide = az.from_numpyro_svi(\n", + " svi_custom_guide,\n", + " svi_result=svi_custom_guide_result,\n", + " model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs\n", + " num_samples = 4000, # number of samples to draw in the posterior\n", + " posterior_predictive=samples_predictive_svi_custom\n", + ")\n", + "idata_svi_custom_guide" + ] + }, + { + "cell_type": "markdown", + "id": "dd794c16", + "metadata": {}, + "source": [ + "## Automatically Labelling Event Dims\n",
+ "\n", + "NumPyro batch dims are automatically labelled according to their corresponding plate names. In order to label event dims, we add `infer={\"event_dims\": dim_labels}` to the `numpyro.sample` statement as shown below:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "62899dda", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/306760900.py:17: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.\n", + " mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 3119.59it/s, 15 steps of size 2.12e-01. acc. prob=0.82]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7054.28it/s, 15 steps of size 2.12e-01. acc. prob=0.86]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 7063.72it/s, 15 steps of size 2.82e-01. acc. prob=0.83]\n", + "sample: 100%|██████████| 2000/2000 [00:00<00:00, 6924.13it/s, 31 steps of size 1.85e-01. acc. 
prob=0.91]\n" + ] + } + ], + "source": [ + "def eight_schools_model_zsn(J, sigma, y=None):\n", + " mu = numpyro.sample(\"mu\", dist.Normal(0, 5))\n", + " tau = numpyro.sample(\"tau\", dist.HalfCauchy(5))\n", + " eta = numpyro.sample(\n", + " \"eta\", \n", + " dist.ZeroSumNormal(tau, event_shape=(J,)),\n", + " # note: this allows arviz to infer the event dimension labels\n", + " infer={\"event_dims\":[\"J\"]}\n", + " )\n", + " with numpyro.plate(\"J\", J):\n", + " theta = numpyro.deterministic(\"theta\", mu + eta)\n", + " return numpyro.sample(\"obs\", dist.Normal(theta, sigma), obs=y)\n", + "\n", + "\n", + "# fit with MCMC\n", + "nuts = NUTS(eight_schools_model_zsn)\n", + "mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)\n", + "mcmc2.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=(\"num_steps\", \"energy\"),)\n", + "\n", + "\n", + "# sample the posterior predictive\n", + "predictive2 = Predictive(eight_schools_model_zsn, mcmc2.get_samples())\n", + "samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)\n", + "\n", + "# Convert the MCMC object to a DataTree\n", + "idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)" + ] + }, + { + "cell_type": "markdown", + "id": "3cc0ad16", + "metadata": {}, + "source": [ + "Notice that eta is labelled appropriately with `J`" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "02b381a5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DatasetView> Size: 0B\n",
+       "Dimensions:  ()\n",
+       "Data variables:\n",
+       "    *empty*
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior\n", + "│ Dimensions: (chain: 4, draw: 1000, J: 8)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ eta (chain, draw, J) float32 128kB 0.2797 2.404 ... -2.467 -2.548\n", + "│ mu (chain, draw) float32 16kB 6.454 4.101 3.804 ... 2.073 6.419 2.331\n", + "│ tau (chain, draw) float32 16kB 5.784 10.78 6.886 ... 1.588 2.861 1.91\n", + "│ theta (chain, draw, J) float32 128kB 6.734 8.858 ... -0.1354 -0.2169\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:31.816711+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /sample_stats\n", + "│ Dimensions: (chain: 4, draw: 1000)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", + "│ Data variables:\n", + "│ diverging (chain, draw) bool 4kB ...\n", + "│ energy (chain, draw) float32 16kB ...\n", + "│ n_steps (chain, draw) int32 16kB ...\n", + "│ tree_depth (chain, draw) int64 32kB 3 4 4 4 4 4 4 4 4 ... 5 4 4 4 4 3 3 5 5\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:31.818312+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (chain: 4, draw: 1000, J: 8)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 
993 994 995 996 997 998 999\n", + "│ * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + "│ Data variables:\n", + "│ obs (chain, draw, J) float32 128kB 22.29 10.91 -11.63 ... 6.396 5.612\n", + "│ theta (chain, draw, J) float32 128kB 8.072 20.36 2.025 ... -2.381 -2.537\n", + "│ Attributes:\n", + "│ created_at: 2025-10-29T16:01:31.819355+00:00\n", + "│ creation_library: ArviZ\n", + "│ creation_library_version: 0.7.0.dev0\n", + "│ creation_library_language: Python\n", + "│ inference_library: numpyro\n", + "│ inference_library_version: 0.19.0\n", + "└── Group: /observed_data\n", + " Dimensions: (J: 8)\n", + " Coordinates:\n", + " * J (J) int64 64B 0 1 2 3 4 5 6 7\n", + " Data variables:\n", + " obs (J) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", + " Attributes:\n", + " created_at: 2025-10-29T16:01:31.820027+00:00\n", + " creation_library: ArviZ\n", + " creation_library_version: 0.7.0.dev0\n", + " creation_library_language: Python\n", + " inference_library: numpyro\n", + " inference_library_version: 0.19.0" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idata_mcmc2" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ac91b263", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The watermark extension is already loaded. 
To reload it, use:\n", + " %reload_ext watermark\n", + "Last updated: Wed Oct 29 2025\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.12.10\n", + "IPython version : 9.4.0\n", + "\n", + "arviz_base: 0.7.0.dev0\n", + "numpyro : 0.19.0\n", + "jax : 0.6.2\n", + "numpy : 2.3.2\n", + "\n", + "Watermark: 2.5.0\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "arviz-dev312", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/index.md b/docs/source/index.md index ced4019..8f4b4e8 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -67,6 +67,7 @@ Thus, to install all user facing optional dependencies you should use `arviz-bas tutorial/WorkingWithDataTree tutorial/label_guide how_to/ConversionGuideEmcee +how_to/ConversionGuideNumPyro ArviZ in Context ::: diff --git a/external_tests/helpers.py b/external_tests/helpers.py index 936f3e5..231fcb4 100644 --- a/external_tests/helpers.py +++ b/external_tests/helpers.py @@ -109,8 +109,30 @@ def _numpyro_noncentered_model(J, sigma, y=None): return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y) +def _numpyro_noncentered_guide(J, sigma, y=None): + import jax + import numpyro + import numpyro.distributions as dist + + # Variational parameters for mu + mu_loc = numpyro.param("mu_loc", 0.0) + mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive) + numpyro.sample("mu", dist.Normal(mu_loc, mu_scale)) + + # Variational parameters for tau (positive support) + tau_loc = numpyro.param("tau_loc", 1.0) + tau_scale = numpyro.param("tau_scale", 0.5, 
constraint=dist.constraints.positive) + numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale)) + + # Variational parameters for eta + eta_loc = numpyro.param("eta_loc", jax.numpy.zeros(J)) + eta_scale = numpyro.param("eta_scale", jax.numpy.ones(J), constraint=dist.constraints.positive) + with numpyro.plate("J", J): + numpyro.sample("eta", dist.Normal(eta_loc, eta_scale)) + + def numpyro_schools_model(data, draws, chains): - """Centered eight schools implementation in NumPyro.""" + """Non-centered eight schools implementation in NumPyro.""" from jax.random import PRNGKey from numpyro.infer import MCMC, NUTS @@ -133,6 +155,35 @@ def numpyro_schools_model(data, draws, chains): return mcmc +def numpyro_schools_model_svi(data, draws, chains): + """Non-centered eight schools in NumPyro, fit with SVI and an AutoNormal guide.""" + from jax.random import PRNGKey + from numpyro.infer import SVI, Trace_ELBO, init_to_sample + from numpyro.infer.autoguide import AutoNormal + from numpyro.optim import Adam + + guide = AutoNormal(_numpyro_noncentered_model, init_loc_fn=init_to_sample()) + svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO()) + svi_result = svi.run(PRNGKey(0), 4000, **data) + return {"svi": svi, "svi_result": svi_result, "model_kwargs": data} + + +def numpyro_schools_model_svi_custom_guide(data, draws, chains): + """Non-centered eight schools in NumPyro, fit with SVI and a custom guide.""" + from jax.random import PRNGKey + from numpyro.infer import SVI, Trace_ELBO + from numpyro.optim import Adam + + guide = _numpyro_noncentered_guide + svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO()) + svi_result = svi.run(PRNGKey(0), 4000, **data) + return { + "svi": svi, + "svi_result": svi_result, + "model_kwargs": data, + } + + def pystan_noncentered_schools(data, draws, chains): """Non-centered eight schools implementation for pystan.""" schools_code = """ @@ -188,10 +239,12 @@ def load_cached_models(eight_schools_data,
draws, chains, libs=None): """Load pystan, emcee, and pyro models from pickle.""" here = os.path.dirname(os.path.abspath(__file__)) supported = ( - # ("pystan", pystan_noncentered_schools), - ("emcee", emcee_schools_model), - # ("pyro", pyro_noncentered_schools), - ("numpyro", numpyro_schools_model), + # ("pystan", pystan_noncentered_schools, None), + ("emcee", emcee_schools_model, None), + # ("pyro", pyro_noncentered_schools, None), + ("numpyro", numpyro_schools_model, None), + ("numpyro", numpyro_schools_model_svi, "numpyro_svi"), + ("numpyro", numpyro_schools_model_svi_custom_guide, "numpyro_svi_custom_guide"), ) data_directory = os.path.join(here, "saved_models") if not os.path.isdir(data_directory): @@ -201,7 +254,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): if isinstance(libs, str): libs = [libs] - for library_name, func in supported: + for library_name, func, addl_model_key in supported: + model_key = addl_model_key or library_name if libs is not None and library_name not in libs: continue library = library_handle(library_name) @@ -214,7 +268,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): py_version = sys.version_info fname = ( - f"{py_version.major}.{py_version.minor}_{library.__name__}_{library.__version__}" + f"{py_version.major}.{py_version.minor}_{model_key}_{library.__version__}" f"_{sys.platform}_{draws}_{chains}.pkl.gzip" ) @@ -225,11 +279,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): _log.info("Generating and caching %s", fname) cloudpickle.dump(func(eight_schools_data, draws, chains), buff) except AttributeError as err: - raise AttributeError(f"Failed caching {library_name}") from err + raise AttributeError(f"Failed caching {model_key}") from err with gzip.open(path, "rb") as buff: _log.info("Loading %s from cache", fname) - models[library.__name__] = cloudpickle.load(buff) + models[model_key] = cloudpickle.load(buff) return models diff --git 
a/external_tests/test_numpyro.py b/external_tests/test_numpyro.py index d250fd8..079f317 100644 --- a/external_tests/test_numpyro.py +++ b/external_tests/test_numpyro.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from arviz_base.io_numpyro import from_numpyro +from arviz_base.io_numpyro import SVIWrapper, from_numpyro, from_numpyro_svi from arviz_base.testing import check_multiple_attrs from .helpers import importorskip, load_cached_models @@ -14,14 +14,15 @@ PRNGKey = jax.random.PRNGKey numpyro = importorskip("numpyro") Predictive = numpyro.infer.Predictive +autoguide = numpyro.infer.autoguide numpyro.set_host_device_count(2) class TestDataNumPyro: - @pytest.fixture(scope="class") - def data(self, eight_schools_params, draws, chains): + @pytest.fixture(scope="class", params=["numpyro", "numpyro_svi", "numpyro_svi_custom_guide"]) + def data(self, request, eight_schools_params, draws, chains): class Data: - obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"] + obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")[request.param] return Data @@ -36,8 +37,9 @@ def predictions_params(self): @pytest.fixture(scope="class") def predictions_data(self, data, predictions_params): """Generate predictions for predictions_params""" - posterior_samples = data.obj.get_samples() - model = data.obj.sampler.model + posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj + posterior_samples = posterior.get_samples() + model = posterior.sampler.model predictions = Predictive(model, posterior_samples)( PRNGKey(2), predictions_params["J"], predictions_params["sigma"] ) @@ -46,8 +48,17 @@ def predictions_data(self, data, predictions_params): def get_inference_data( self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False ): - posterior_samples = data.obj.get_samples() - model = data.obj.sampler.model + if isinstance(data.obj, dict): # SVI cached data obj is a dict of from_numpyro_svi kwargs + posterior
= SVIWrapper(**data.obj) + from_numpyro_func = from_numpyro_svi + posterior_kwarg = data.obj + else: # regular MCMC + posterior = data.obj + from_numpyro_func = from_numpyro + posterior_kwarg = {"posterior": posterior} + + posterior_samples = posterior.get_samples() + model = posterior.sampler.model posterior_predictive = Predictive(model, posterior_samples)( PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"] ) @@ -60,8 +71,9 @@ def get_inference_data( dims = pred_dims = None predictions = predictions_data - return from_numpyro( - posterior=data.obj, + + return from_numpyro_func( + **posterior_kwarg, prior=prior, posterior_predictive=posterior_predictive, predictions=predictions, @@ -74,17 +86,18 @@ def get_inference_data( ) def test_inference_data_namedtuple(self, data): - samples = data.obj.get_samples() + posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj + samples = posterior.get_samples() Samples = namedtuple("Samples", samples) data_namedtuple = Samples(**samples) - _old_fn = data.obj.get_samples - data.obj.get_samples = lambda *args, **kwargs: data_namedtuple + _old_fn = posterior.get_samples + posterior.get_samples = lambda *args, **kwargs: data_namedtuple inference_data = from_numpyro( - posterior=data.obj, + posterior=posterior, dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain ) - assert isinstance(data.obj.get_samples(), Samples) - data.obj.get_samples = _old_fn + assert isinstance(posterior.get_samples(), Samples) + posterior.get_samples = _old_fn for key in samples: assert key in inference_data.posterior @@ -101,6 +114,8 @@ def test_inference_data(self, data, eight_schools_params, predictions_data, pred "prior_predictive": ["obs"], "observed_data": ["obs"], } + if isinstance(data.obj, dict): # if its SVI, drop sample_stats check + test_dict.pop("sample_stats") fails = check_multiple_attrs(test_dict, inference_data) assert not fails @@ -113,8 +128,9 @@ def 
test_inference_data(self, data, eight_schools_params, predictions_data, pred def test_inference_data_no_posterior( self, data, eight_schools_params, predictions_data, predictions_params ): - posterior_samples = data.obj.get_samples() - model = data.obj.sampler.model + posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj + posterior_samples = posterior.get_samples() + model = posterior.sampler.model posterior_predictive = Predictive(model, posterior_samples)( PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"] ) @@ -161,11 +177,15 @@ def test_inference_data_no_posterior( assert not fails, f"prior and posterior_predictive: {fails}" def test_inference_data_only_posterior(self, data): - idata = from_numpyro(data.obj) + kwargs = data.obj if isinstance(data.obj, dict) else {"posterior": data.obj} + from_numpyro_func = from_numpyro_svi if isinstance(data.obj, dict) else from_numpyro + idata = from_numpyro_func(**kwargs) test_dict = { "posterior": ["mu", "tau", "eta"], "sample_stats": ["diverging"], } + if isinstance(data.obj, dict): + test_dict.pop("sample_stats") fails = check_multiple_attrs(test_dict, idata) assert not fails @@ -282,28 +302,59 @@ def model(): inference_data = from_numpyro(mcmc) assert inference_data.observed_data - def test_mcmc_infer_dims(self): + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_infer_dims(self, svi, guide_fn): + import jax.numpy as jnp import numpyro import numpyro.distributions as dist - from numpyro.infer import MCMC, NUTS def model(): # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2 with numpyro.plate("group2", 5), numpyro.plate("group1", 10): _ = numpyro.sample("param", dist.Normal(0, 1)) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) 
- inference_data = from_numpyro( - mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)} + def guide(): + loc = numpyro.param("param_loc", jnp.zeros((10, 5))) + scale = numpyro.param( + "param_scale", jnp.ones((10, 5)), constraint=dist.constraints.positive + ) + with numpyro.plate("group2", 5), numpyro.plate("group1", 10): + numpyro.sample("param", dist.Normal(loc, scale)) + + if guide_fn == "custom": + guide_fn = guide + + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + sample_dims = ("sample",) if svi else ("chain", "draw") + + inference_data = from_numpyro_func( + **result, coords={"group1": np.arange(10), "group2": np.arange(5)} ) - assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2") + assert inference_data.posterior.param.dims == sample_dims + ("group1", "group2") assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2")) - def test_mcmc_infer_unsorted_dims(self): + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_infer_unsorted_dims(self, svi, guide_fn): + import jax.numpy as jnp import numpyro import numpyro.distributions as dist - from numpyro.infer import MCMC, NUTS def model(): group1_plate = numpyro.plate("group1", 10, dim=-1) @@ -314,49 +365,113 @@ def model(): with group2_plate, group1_plate: _ = numpyro.sample("param", dist.Normal(0, 1)) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) - inference_data = from_numpyro( - mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)} + def guide(): + loc = numpyro.param("param_loc", jnp.zeros((5, 10))) + scale = numpyro.param( + "param_scale", jnp.ones((5, 10)), constraint=dist.constraints.positive + ) + group1_plate 
= numpyro.plate("group1", 10, dim=-1) + group2_plate = numpyro.plate("group2", 5, dim=-2) + with group2_plate, group1_plate: + numpyro.sample("param", dist.Normal(loc, scale)) + + if guide_fn == "custom": + guide_fn = guide + + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + sample_dims = ("sample",) if svi else ("chain", "draw") + + inference_data = from_numpyro_func( + **result, coords={"group1": np.arange(10), "group2": np.arange(5)} ) - assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1") + assert inference_data.posterior.param.dims == sample_dims + ("group2", "group1") assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2")) - def test_mcmc_infer_dims_no_coords(self): + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_infer_dims_no_coords(self, svi, guide_fn): + import jax.numpy as jnp import numpyro import numpyro.distributions as dist - from numpyro.infer import MCMC, NUTS def model(): with numpyro.plate("group", 5): _ = numpyro.sample("param", dist.Normal(0, 1)) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) - inference_data = from_numpyro(mcmc) - assert inference_data.posterior.param.dims == ("chain", "draw", "group") - - def test_mcmc_event_dims(self): + def guide(): + loc = numpyro.param("param_loc", jnp.zeros(5)) + scale = numpyro.param("param_scale", jnp.ones(5), constraint=dist.constraints.positive) + with numpyro.plate("group", 5): + numpyro.sample("param", dist.Normal(loc, scale)) + + if guide_fn == "custom": + guide_fn = guide + + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + 
sample_dims = ("sample",) if svi else ("chain", "draw") + + inference_data = from_numpyro_func(**result) + assert inference_data.posterior.param.dims == sample_dims + ("group",) + + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_event_dims(self, svi, guide_fn): import numpyro import numpyro.distributions as dist - from numpyro.infer import MCMC, NUTS def model(): _ = numpyro.sample( "gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]} ) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) - inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)}) - assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups") + def guide(): + scale = numpyro.param( + "gamma_scale", + 1.0, + constraint=dist.constraints.positive, + ) + numpyro.sample("gamma", dist.ZeroSumNormal(scale, event_shape=(10,))) + + if guide_fn == "custom": + guide_fn = guide + + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + sample_dims = ("sample",) if svi else ("chain", "draw") + + inference_data = from_numpyro_func(**result, coords={"groups": np.arange(10)}) + assert inference_data.posterior.gamma.dims == sample_dims + ("groups",) assert "groups" in inference_data.posterior.gamma.coords - def test_mcmc_inferred_dims_univariate(self): + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_inferred_dims_univariate(self, svi, guide_fn): import jax.numpy as jnp import numpyro import numpyro.distributions as dist - from numpyro.infer 
import MCMC, NUTS def model(): alpha = numpyro.sample("alpha", dist.Normal(0, 1)) @@ -367,33 +482,130 @@ def model(): mu = numpyro.deterministic("mu", alpha) return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1])) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) - with pytest.raises(ValueError): - from_numpyro(mcmc, coords={"obs_idx": np.arange(3)}) + def guide(): + alpha_loc = numpyro.param("alpha_loc", jnp.array(0.0)) + alpha_scale = numpyro.param( + "alpha_scale", jnp.array(1.0), constraint=dist.constraints.positive + ) + sigma_loc = numpyro.param( + "sigma_loc", jnp.array(1.0), constraint=dist.constraints.positive + ) + + alpha = numpyro.sample("alpha", dist.Normal(alpha_loc, alpha_scale)) + numpyro.sample("sigma", dist.HalfNormal(sigma_loc)) + with numpyro.plate("obs_idx", 3): + numpyro.deterministic("mu", alpha) + + if guide_fn == "custom": + guide_fn = guide - def test_mcmc_extra_event_dims(self): + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + with pytest.raises(ValueError): + from_numpyro_func(**result, coords={"obs_idx": np.arange(3)}) + + @pytest.mark.parametrize( + "svi,guide_fn", + [ + (False, None), # MCMC, guide ignored + (True, autoguide.AutoDelta), # SVI with AutoDelta + (True, autoguide.AutoNormal), # SVI with AutoNormal + (True, "custom"), # SVI with custom guide + ], + ) + def test_extra_event_dims(self, svi, guide_fn): import numpyro import numpyro.distributions as dist - from numpyro.infer import MCMC, NUTS def model(): gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,))) _ = numpyro.deterministic("gamma_plus1", gamma + 1) - mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) - mcmc.run(PRNGKey(0)) - inference_data = from_numpyro( - mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]} + def guide(): + scale = numpyro.param( + "gamma_scale", + 1.0, + 
constraint=dist.constraints.positive, + ) + gamma = numpyro.sample("gamma", dist.ZeroSumNormal(scale, event_shape=(10,))) + numpyro.deterministic("gamma_plus1", gamma + 1) + + if guide_fn == "custom": + guide_fn = guide + + result = self._run_inference(model, svi=svi, guide_fn=guide_fn) + from_numpyro_func = from_numpyro_svi if svi else from_numpyro + sample_dims = ("sample",) if svi else ("chain", "draw") + inference_data = from_numpyro_func( + **result, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]} ) - assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups") + assert inference_data.posterior.gamma_plus1.dims == sample_dims + ("groups",) assert "groups" in inference_data.posterior.gamma_plus1.coords - def test_mcmc_predictions_infer_dims( + def test_predictions_infer_dims( self, data, eight_schools_params, predictions_data, predictions_params ): inference_data = self.get_inference_data( data, eight_schools_params, predictions_data, predictions_params, infer_dims=True ) - assert inference_data.predictions.obs.dims == ("chain", "draw", "J") + sample_dims = ("sample",) if isinstance(data.obj, dict) else ("chain", "draw") + assert inference_data.predictions.obs.dims == (sample_dims + ("J",)) assert "J" in inference_data.predictions.obs.coords + + def _run_inference(self, model, svi, guide_fn): + from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO + from numpyro.optim import Adam + + if svi: + is_autoguide = isinstance(guide_fn, type) and issubclass(guide_fn, autoguide.AutoGuide) + guide = guide_fn(model) if is_autoguide else guide_fn + svi = SVI(model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO()) + svi_result = svi.run(PRNGKey(0), 10) + return { + "svi": svi, + "svi_result": svi_result, + "model": None if is_autoguide else model, + } + + else: + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) + mcmc.run(PRNGKey(0)) + return {"posterior": mcmc} + + +class TestSVIWrapper: + 
@pytest.fixture(scope="class", params=["numpyro_svi", "numpyro_svi_custom_guide"]) + def data(self, request, eight_schools_params, draws, chains): + class Data: + obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")[request.param] + + return Data + + def test_init_without_args_kwargs(self): + from numpyro.infer import Trace_ELBO + from numpyro.infer.svi import SVI, SVIRunResult + from numpyro.optim import Adam + + model = guide = lambda x: x + svi = SVI(model, guide, optim=Adam(0.05), loss=Trace_ELBO()) + svi_result = SVIRunResult(params=jax.numpy.ones(5), state=None, losses=jax.numpy.zeros(10)) + + posterior = SVIWrapper(svi, svi_result=svi_result) + assert isinstance(posterior._args, tuple) + assert isinstance(posterior._kwargs, dict) + + def test_get_samples(self, data, eight_schools_params): + svi_posterior = SVIWrapper( + data.obj["svi"], svi_result=data.obj["svi_result"], model_kwargs=eight_schools_params + ) + out = svi_posterior.get_samples(seed=0) + assert isinstance(out, dict) + for v in out.values(): # values are array-like + assert isinstance(v, (jax.numpy.ndarray | np.ndarray)) + + def test_sampler_attr(self, data, eight_schools_params): + svi_posterior = SVIWrapper( + data.obj["svi"], svi_result=data.obj["svi_result"], model_kwargs=eight_schools_params + ) + assert hasattr(svi_posterior, "sampler") + assert hasattr(svi_posterior.sampler, "model") diff --git a/src/arviz_base/__init__.py b/src/arviz_base/__init__.py index 373499d..cfbfd9f 100644 --- a/src/arviz_base/__init__.py +++ b/src/arviz_base/__init__.py @@ -13,7 +13,7 @@ from arviz_base.io_cmdstanpy import from_cmdstanpy from arviz_base.io_dict import from_dict from arviz_base.io_emcee import from_emcee -from arviz_base.io_numpyro import from_numpyro +from arviz_base.io_numpyro import from_numpyro, from_numpyro_svi from arviz_base.io_pystan import from_pystan from arviz_base.rcparams import rc_context, rcParams from arviz_base.reorg import ( @@ -49,6 +49,7 @@ "from_dict", 
"from_emcee", "from_numpyro", + "from_numpyro_svi", # rcparams "rc_context", "rcParams", diff --git a/src/arviz_base/__init__.pyi b/src/arviz_base/__init__.pyi index 6c34f4f..f239fea 100644 --- a/src/arviz_base/__init__.pyi +++ b/src/arviz_base/__init__.pyi @@ -25,7 +25,7 @@ from arviz_base.datasets import ( from arviz_base.io_cmdstanpy import from_cmdstanpy from arviz_base.io_dict import from_dict from arviz_base.io_emcee import from_emcee -from arviz_base.io_numpyro import from_numpyro +from arviz_base.io_numpyro import from_numpyro, from_numpyro_svi from arviz_base.io_pystan import from_pystan from arviz_base.rcparams import rc_context, rcParams from arviz_base.reorg import ( @@ -55,6 +55,7 @@ __all__ = [ "from_dict", "from_emcee", "from_numpyro", + "from_numpyro_svi", "rc_context", "rcParams", "extract", diff --git a/src/arviz_base/io_numpyro.py b/src/arviz_base/io_numpyro.py index 359af61..7d05fff 100644 --- a/src/arviz_base/io_numpyro.py +++ b/src/arviz_base/io_numpyro.py @@ -11,6 +11,71 @@ from arviz_base.utils import expand_dims +class SVIWrapper: + """A helper class for SVI to mimic numpyro.infer.MCMC methods.""" + + def __init__( + self, + svi, + *, + svi_result, + model_args=None, + model_kwargs=None, + num_samples: int = 1000, + ): + import jax + import numpyro + + self.svi = svi + self.svi_result = svi_result + self._args = model_args or tuple() + self._kwargs = model_kwargs or dict() + self.num_samples = num_samples + self.thinning = 1 + self.num_chains = 0 + self.sample_dims = ["sample"] + self.kind = "svi" + + self.numpyro = numpyro + self.prng_key_func = jax.random.PRNGKey + + def get_samples(self, seed=None, **kwargs): + """Mimics mcmc.get_samples().""" + key = self.prng_key_func(seed or 0) + if isinstance(self.svi.guide, self.numpyro.infer.autoguide.AutoGuide): + return self.svi.guide.sample_posterior( + key, + self.svi_result.params, + *self._args, + sample_shape=(self.num_samples,), + **self._kwargs, + ) + # if a custom guide is provided, 
sample by hand + predictive = self.numpyro.infer.Predictive( + self.svi.guide, params=self.svi_result.params, num_samples=self.num_samples + ) + samples = predictive(key, *self._args, **self._kwargs) + return samples + + @property + def sampler(self): + """Mimics mcmc.sampler.model.""" + + class Sampler: + def __init__(self, model): + self._model = model + + @property + def model(self): + return self._model + + return Sampler(getattr(self.svi.guide, "model", self.svi.model)) + + def get_extra_fields(self, **kwargs): + """Mimics mcmc.get_extra_fields().""" + return dict() + + def _add_dims(dims_a, dims_b): """Merge two dimension mappings by concatenating dimension labels. @@ -230,7 +295,10 @@ def arbitrary_element(dct): observations = {} if self.model is not None: trace = self._get_model_trace( - self.model, self._args, self._kwargs, key=jax.random.PRNGKey(0) + self.model, + model_args=self._args, + model_kwargs=self._kwargs, + key=jax.random.PRNGKey(0), ) observations = { name: site["value"] @@ -239,14 +307,17 @@ def arbitrary_element(dct): } self.observations = observations if observations else None - def _get_model_trace(self, model, args, kwargs, key): + def _get_model_trace(self, model, model_args, model_kwargs, key): """Extract the numpyro model trace.""" + model_args = model_args or tuple() + model_kwargs = model_kwargs or dict() + # we need to use an init strategy to generate random samples for ImproperUniform sites seeded_model = self.numpyro.handlers.substitute( self.numpyro.handlers.seed(model, key), substitute_fn=self.numpyro.infer.init_to_sample, ) - trace = self.numpyro.handlers.trace(seeded_model).get_trace(*args, **kwargs) + trace = self.numpyro.handlers.trace(seeded_model).get_trace(*model_args, **model_kwargs) return trace @requires("posterior") @@ -321,8 +392,13 @@ def translate_posterior_predictive_dict_to_xarray(self, dct, dims): shape = ary.shape if shape[0] == self.nchains and shape[1] == self.ndraws: data[k] = ary - elif shape[0] == 
self.nchains * self.ndraws: + elif ( + shape[0] == self.nchains * self.ndraws + and getattr(self.posterior, "kind", "") != "svi" + ): data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:])) + elif getattr(self.posterior, "kind", "") == "svi": + data[k] = ary else: data[k] = expand_dims(ary) warnings.warn( @@ -359,12 +435,17 @@ def priors_to_xarray(self): else: prior_vars = self.prior.keys() prior_predictive_vars = None + + # dont expand dims for SVI + expand_dims_func = ( + expand_dims if getattr(self.posterior, "kind", "") != "svi" else lambda x: x + ) priors_dict = { group: ( None if var_names is None else dict_to_dataset( - {k: expand_dims(self.prior[k]) for k in var_names}, + {k: expand_dims_func(self.prior[k]) for k in var_names}, inference_library=self.numpyro, coords=self.coords, dims=self.dims, @@ -473,6 +554,8 @@ def from_numpyro( ): """Convert NumPyro data into a DataTree object. + For a usage example read :ref:`numpyro_conversion` + If no dims are provided, this will infer batch dim names from NumPyro model plates. For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}` can be provided in numpyro.sample, i.e.:: @@ -487,9 +570,6 @@ def from_numpyro( There is also an additional `extra_event_dims` input to cover any edge cases, for instance deterministic sites with event dims (which dont have an `infer` argument to provide metadata). 
- For a usage example read the - :ref:`Creating InferenceData section on from_numpyro ` - Parameters ---------- posterior : numpyro.mcmc.MCMC @@ -538,3 +618,104 @@ def from_numpyro( extra_event_dims=extra_event_dims, num_chains=num_chains, ).to_datatree() + + +def from_numpyro_svi( + svi, + *, + svi_result, + model_args=None, + model_kwargs=None, + prior=None, + posterior_predictive=None, + predictions=None, + constant_data=None, + predictions_constant_data=None, + log_likelihood=None, + index_origin=None, + coords=None, + dims=None, + pred_dims=None, + extra_event_dims=None, + model=None, + num_samples: int = 1000, +): + """Convert NumPyro SVI results into a DataTree object. + + For a usage example read :ref:`numpyro_conversion` + + If no dims are provided, this will infer batch dim names from NumPyro model plates. + For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}` + can be provided in numpyro.sample, i.e.:: + + # equivalent to dims entry, {"gamma": ["groups"]} + gamma = numpyro.sample( + "gamma", + dist.ZeroSumNormal(1, event_shape=(n_groups,)), + infer={"event_dims":["groups"]} + ) + + There is also an additional `extra_event_dims` input to cover any edge cases, for instance + deterministic sites with event dims (which dont have an `infer` argument to provide metadata). + + Parameters + ---------- + svi : numpyro.infer.svi.SVI + Numpyro SVI instance used for fitting the model. + svi_result : numpyro.infer.svi.SVIRunResult + SVI results from a fitted model. + model_args : tuple, optional + Model arguments, should match those used for fitting the model. + model_kwargs : dict, optional + Model keyword arguments, should match those used for fitting the model. 
+ prior : dict, optional + Prior samples from a NumPyro model + posterior_predictive : dict, optional + Posterior predictive samples for the posterior + predictions : dict, optional + Out-of-sample predictions + constant_data : dict, optional + Dictionary containing constant data variables mapped to their values. + predictions_constant_data : dict, optional + Constant data used for out-of-sample predictions. + index_origin : int, optional + coords : dict, optional + Map of dimensions to coordinates + dims : dict of {str : list of str}, optional + Map variable names to their coordinates. Will be inferred if they are not provided. + pred_dims : dict, optional + Dims for predictions data. Map variable names to their coordinates. Default behavior is to + infer dims if this is not provided + extra_event_dims : dict, optional + Extra event dims for deterministic sites. Maps event dims that could not be inferred to + their coordinates. + num_samples : int, default 1000 + Number of posterior samples to draw from the fitted guide. 
+ + Returns + ------- + DataTree + """ + posterior = SVIWrapper( + svi, + svi_result=svi_result, + model_args=model_args, + model_kwargs=model_kwargs, + num_samples=num_samples, + ) + with rc_context(rc={"data.sample_dims": ["sample"]}): + return NumPyroConverter( + posterior=posterior, + prior=prior, + posterior_predictive=posterior_predictive, + predictions=predictions, + constant_data=constant_data, + predictions_constant_data=predictions_constant_data, + log_likelihood=log_likelihood, + index_origin=index_origin, + coords=coords, + dims=dims, + pred_dims=pred_dims, + extra_event_dims=extra_event_dims, + num_chains=0, + ).to_datatree() diff --git a/src/arviz_base/io_numpyro.pyi b/src/arviz_base/io_numpyro.pyi index c50c751..2d46285 100644 --- a/src/arviz_base/io_numpyro.pyi +++ b/src/arviz_base/io_numpyro.pyi @@ -14,6 +14,21 @@ from arviz_base.base import dict_to_dataset, requires from arviz_base.rcparams import rc_context, rcParams from arviz_base.utils import expand_dims +class SVIWrapper: + def __init__( + self, + svi, + *, + svi_result, + model_args=..., + model_kwargs=..., + num_samples: int = ..., + ) -> None: ... + def get_samples(self, seed=..., **kwargs) -> None: ... + @property + def sampler(self) -> None: ... + def get_extra_fields(self, **kwargs) -> None: ... + def _add_dims( dims_a: dict[str, list[str]], dims_b: dict[str, list[str]] ) -> dict[str, list[str]]: ... @@ -46,7 +61,7 @@ class NumPyroConverter: extra_event_dims: dict | None = ..., num_chains: int = ..., ) -> None: ... - def _get_model_trace(self, model, args, kwargs, key) -> None: ... + def _get_model_trace(self, model, model_args, model_kwargs, key) -> None: ... def posterior_to_xarray(self) -> None: ... def sample_stats_to_xarray(self) -> None: ... def log_likelihood_to_xarray(self) -> None: ... @@ -77,3 +92,23 @@ def from_numpyro( extra_event_dims: dict | None = ..., num_chains: int = ..., ) -> DataTree: ... 
+def from_numpyro_svi( + svi: numpyro.infer.svi.SVI, + *, + svi_result: numpyro.infer.svi.SVIRunResult, + model_args: tuple | None = ..., + model_kwargs: dict | None = ..., + prior: dict | None = ..., + posterior_predictive: dict | None = ..., + predictions: dict | None = ..., + constant_data: dict | None = ..., + predictions_constant_data: dict | None = ..., + log_likelihood=..., + index_origin: int | None = ..., + coords: dict | None = ..., + dims: dict[str, list[str]] | None = ..., + pred_dims: dict | None = ..., + extra_event_dims: dict | None = ..., + model=..., + num_samples: int = ..., +) -> DataTree: ...