diff --git a/CHANGELOG.md b/CHANGELOG.md index 4add4b5ca..a46e6d250 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file. - [#1432](https://github.com/pints-team/pints/pull/1432) Added 2 new stochastic models: production and degradation model, Schlogl's system of chemical reactions. Moved the stochastic logistic model into `pints.stochastic` to take advantage of the `MarkovJumpModel`. - [#1420](https://github.com/pints-team/pints/pull/1420) The `Optimiser` class now distinguishes between a best-visited point (`x_best`, with score `f_best`) and a best-guessed point (`x_guessed`, with approximate score `f_guessed`). For most optimisers, the two values are equivalent. The `OptimisationController` still tracks `x_best` and `f_best` by default, but this can be modified using the methods `set_f_guessed_tracking` and `f_guessed_tracking`. - [#1417](https://github.com/pints-team/pints/pull/1417) Added a module `toy.stochastic` for stochastic models. In particular, `toy.stochastic.MarkovJumpModel` implements Gillespie's algorithm for easier future implementation of stochastic models. +- [#1413](https://github.com/pints-team/pints/pull/1413) Added classes `pints.ABCController` and `pints.ABCSampler` for Approximate Bayesian computation (ABC) samplers. Added `pints.RejectionABC` which implements a simple rejection ABC sampling algorithm. ### Changed - [#1439](https://github.com/pints-team/pints/pull/1439), [#1433](https://github.com/pints-team/pints/pull/1433) PINTS is no longer tested on Python 3.5. Testing for Python 3.10 has been added. diff --git a/docs/source/abc_samplers/base_classes.rst b/docs/source/abc_samplers/base_classes.rst new file mode 100644 index 000000000..e70b022bb --- /dev/null +++ b/docs/source/abc_samplers/base_classes.rst @@ -0,0 +1,8 @@ +********************** +ABC sampler base class +********************** + +.. currentmodule:: pints + +.. autoclass:: ABCSampler +.. autoclass:: ABCController \ No newline at end of file diff --git a/docs/source/abc_samplers/index.rst b/docs/source/abc_samplers/index.rst new file mode 100644 index 000000000..96b4e9fc9 --- /dev/null +++ b/docs/source/abc_samplers/index.rst @@ -0,0 +1,16 @@ +************ +ABC samplers +************ + +.. currentmodule:: pints + +Pints provides a number of samplers for Approximate Bayesian +Computation (ABC), all implementing the :class:`ABCSampler` +interface, that can be used to sample from a stochastic model +given a :class:`LogPrior` and a :class:`ErrorMeasure`. + + +.. toctree:: + + base_classes + rejection_abc \ No newline at end of file diff --git a/docs/source/abc_samplers/rejection_abc.rst b/docs/source/abc_samplers/rejection_abc.rst new file mode 100644 index 000000000..4bbf4632a --- /dev/null +++ b/docs/source/abc_samplers/rejection_abc.rst @@ -0,0 +1,7 @@ +********************* +Rejection ABC sampler +********************* + +.. currentmodule:: pints + +.. autoclass:: RejectionABC \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 543eb98ba..fc85f8631 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ Contents .. toctree:: + abc_samplers/index boundaries core_classes_and_methods diagnostics @@ -78,10 +79,10 @@ Sampling - SMC -#. Likelihood free sampling (Need distance between data and states, e.g. least squares?) +#. :class:`ABC sampling` - - ABC-MCMC - - ABC-SMC + - :class:`RejectionABC`, requires a :class:`LogPrior` that can be sampled + from and an error measure. #. 1st order sensitivity MCMC samplers (Need derivatives of :class:`LogPDF`) diff --git a/examples/README.md b/examples/README.md index c795ee6d1..923b08969 100644 --- a/examples/README.md +++ b/examples/README.md @@ -77,6 +77,9 @@ relevant code. - [Ellipsoidal nested sampling](./sampling/nested-ellipsoidal-sampling.ipynb) - [Rejection nested sampling](./sampling/nested-rejection-sampling.ipynb) +### ABC +- [Rejection ABC sampling](./sampling/rejection-abc.ipynb) + ### Analysing sampling results - [Autocorrelation](./plotting/mcmc-autocorrelation.ipynb) - [Customise analysis plots](./plotting/customise-pints-plots.ipynb) diff --git a/examples/sampling/rejection-abc.ipynb b/examples/sampling/rejection-abc.ipynb new file mode 100644 index 000000000..6cc700f81 --- /dev/null +++ b/examples/sampling/rejection-abc.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rejection ABC\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "PINTS can be used to perform inference for stochastic forward models. Here, we perform inference on the [stochastic degradation model](../toy/model-stochastic-degradation.ipynb) using Approximate Bayesian Computation (ABC). This model has only a single unknown parameter -- the rate at which chemicals degrade. We use the \"rejection ABC\" method to estimate this unknown and provide a measure of uncertainty in it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we load the stochastic degradation model and perform 10 simulations from it. The variation inbetween runs is due to the inherent stochasticity of this type of model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import pints\n", + "import pints.toy as toy\n", + "import pints.toy.stochastic\n", + "import pints.plot\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAgAElEQVR4nO3de3xdZZ3v8c9v59I0Se9NL7SUAlNQILWUcJcRAbECiuUAhQpCYabMjBfGwzlax3OGF8fRg6M4Ol7pcQioyBQRFFFQZFAcbmNbKqU3wEJLIaEptZc0zXX/zh9rpSTZ2ZfsZO2dZH3fr1fI3s9+1np+aekvz372s37L3B0REYmPRLEDEBGRwlLiFxGJGSV+EZGYUeIXEYkZJX4RkZgpLXYAuZg6darPnTu32GGIiIwoa9as2eXuNX3bR0Tinzt3LqtXry52GCIiI4qZbeuvXUs9IiIxo8QvIhIzSvwiIjGjxC8iEjNK/CIiMRNZ4jezw83scTPbaGYbzOzGsH2ymT1qZi+F3ydFFYOIiKSKcsbfCdzk7scBpwEfM7PjgBXAY+4+D3gsfC4iIgUS2T5+d28AGsLH+81sEzALuBg4O+x2F/Bb4DNRxHDlT7/D65VTDz0/tfUPnNn8m4zHdFmSpHXlPaZRgvXzx9pRXkVHeTUAL+8/g837zsl4nosXzGLpqXPyjkNEJJ2CrPGb2VzgROBZYHr4SwGgEZie5pjlZrbazFY3NTUNOobXSmbxbMXJmeN0o8QH80eSxEn9pVHS1U5Z+wEApozZzl+MeyrjWTY27ONn614fRBwiIulFfuWumVUDPwH+3t33mdmh19zdzazfO8G4+0pgJUBdXV1ed4u59VsrATji0ec47ae/o51pfHjp2rT9H7gteO2ymxbmMxzLHlkGQP2i+l7tq24JVrOuvvlW1qxdynTg6otOT3ueJbc/ndf4IiK5iHTGb2ZlBEn/bne/P2x+08xmhq/PBHZGGYOIiPQW5a4eA/4N2OTuX+3x0oPANeHja4CfRRWDiIikinKp50zgamC9ma0L2/4BuBW418yuB7YBl0cYg4iI9BHlrp7/BCzNy+dGNa6IiGSmK3dFRGJGiV9EJGaU+EVEYkaJX0QkZpT4RURiRolfRCRmzD2vaggFVVdX5/ncbH3b+06ktfEgFTPGcvnffIXGceOp7ShL27+jrQtPOpZ4exfqaW8lWbCvJKfxtuzeTEvnQSpLx/ZqL9nxGxLte6ioOow55z5LxeQDJFuPTh3/z++m4633sbFhHy1tnVSOCXbb7pxUQuPUaKtrqCicyOhjZmvcva5v+6ie8Y9/319SMSNIwrWvbWXG/n0Z+ydKrFfS/1N1gmem5P5HNHnslJSkD9BadRjJ8okA7H6xitbdVSl9Ssa+Stmk/wRganX5oaRffTDJtD/nXy00FyoKJxIvkRdpK6ZJn/46kz4dPK771Mep2/ZH/uZfvpnz8Rf9Yh2MKWHxTQsGFcfbxdu+yapbVtDyInzw5lt79VmzdikAJ13cu0Bcd+G4z92QX+G4XKgonEi8jOoZv4iIpFLiFxGJGSV+EZGYUeIXEYkZJX4RkZhR4hcRiRklfhGRmIny1ot3mNlOM3uhR9sCM3vGzNaZ2WozOyWq8UVEpH9RzvjvBBb1aftn4BZ3XwD8Y/hcREQKKMpbLz5hZnP7NgPjw8cTgDeiGr8/b40dz3c/9fGU9qOadnDsm9tS2pN//a+8OLGKC+79fdDQsR86grIP5294hsXrfpd5wKoaGDeDK8IaPr/6xinsHjuR1kQp37n0fADKEmWUJcqouWgrZVMO8ssfzgeg5eWJHNg8mX2JyXRSyjeWdgIwxg8yxg/2O1zplCmUTqvJGNI7zzyb+ef1/X0clG3IdgWv6vmIjA6FLtnw98CvzOwrBO82zkjX0cyWA8sB5swZfLI5rmYsG5tSa/XsrpoA0G/if/+mdfDOsFxDYgyUAR37eHF6EE/GxN9+IPg+bgZTxk6Bg28BMLGjlT1lFQAkPUlHsoOyRBktL0+kMjy0bMpBKoEDmycHSd6C+j+dlIKN7TfxJ1ta6ISMib/p1VcAUhL/xQtmpf85Qhsbgj87JX6RkS/S6pzhjP8hdz8hfP6vwO/c/Sdmdjmw3N3Py3aefKtz5qK+vh6AZcuWZey38/bnAZh2w3wWP/cSAA+cOC/DiS8Mvi/7Rdoub9fwqe/Vfqhuz8If9Wrvrtuz+KbUuj3brv4oAEf84Ptpx1t1ywoAlvSpE5SL7ncDq244fcDHikhxDJfqnNcA94ePfwzow10RkQIrdOJ/A3hP+Pgc4KUCjy8iEnuRrfGb2T3A2cBUM9sB3Az8NfB1MysFWgnX8EVEpHCi3NVzZZqXTopqTBERyU5X7oqIxIwSv4hIzCjxi4jEjBK/iEjMKPGLiMRMpFfuDpWor9xtbGxkxowZGft1NBzgqPZpzJ/5Dq6bnWRLpXHCxMr0BzSuD8o2lFel7bKFDlpIUtnn9++ksk5KE9CZDBssAWYkOksxT+CWTDlXIgmnHlzDew4+A0BXidFVagB0jmmlo+IgFbuDej+tk3PfzDW5YjI1ldPY+MZefr9jIXvtgpQ+quEjMjwNlyt3h53a2tqsSR/greQ+tpbvBOD9O9o4tiXLL8yqmoxJH2AyiZSkD9CaTLyd9AE8eOKJZL9JH2B72SyeHRvslDV3SrqC+Eo6yyhtq8gcaxotnS3sbt0NwIzK1zhr9tqUPhsb9vGzda/ndX4RKY5CF2kbdurq6qirS/mFmKK7ps+0ZfNZcvvzLNkB0y7MUKuHTK8NQA41f4CwftBULlz0eK+6Pd31fZbetPBQrZ6lOdbqWfbIMuiEpQvrWbN2KZMnwtUX9a7Vk62ip4gMP7Gf8YuIxI0Sv4hIzCjxi4jEjBK/iEjMKPGLiMSMEr+ISMwo8YuIxIwSv4hIzESW+M3sDjPbaWYv9Gn/hJltNrMNZvbPUY0vIiL9i3LGfyewqGeDmb0XuBh4l7sfD3wlwvFFRKQfkRZpM7O5wEPufkL4/F5gpbv/ZiDnibJIW656FnPraDiAt3dh5SUZjzl+3js587JzBznwhUHBtxm1GbstnnE9G8pncnx7A21vtpPsSJIoS7B3zGF02RhKvI2O5F7cOynz/uv99NVFWO8Ho6ymjUSZU+q95wpd7nR4CV2JMgBmNTeTeHOIylXkQAXiRNIbLkXajgHOMrNnzex3ZnZyuo5mttzMVpvZ6qampgKG2L+exdwS1WVZk/6u9j1seGnTEAx8adakD3BJ8/Mc394AQElVCYmy4K92TGczJd4WdEpUEdznfuA6WkpJdlhKe4IkZdYFQENiOq9XV+d1/nyoQJxIfgpdpK0UmAycBpwM3GtmR3k/bzvcfSWwEoIZf0Gj7Eeuxdy6rfziN4do4GXBVxZXh1+ZBAXbJrL4poU5Db3skWDc+kX13HXFhwG45t9/2rtTjyJy73viQaCUVTf0LuQWFRWIE8lP1sRvZhXARcBZwGHAQeAF4BfuvmGA4+0A7g8T/X+ZWRKYChR/Si8iEhMZl3rM7BbgSeB04FngduBeoBO41cweNbP5Axjvp8B7w3MfA5QDu/KIW0RE8pRtxv9f7n5zmte+ambTgH4/WTOze4CzgalmtgO4GbgDuCPc4tkOXNPfMo+IiEQnY+J394x3/3D3ncDONK9dmeawq3ILTUREopDLGv9s4Erg3fRZ4wceds9xb6CIiAwLGRO/mdUDs4CHgC8RzO4rCLZlLgI+Z2Yr3P2JqAMVEZGhkW3Gf5u7v9BP+wvA/WZWTpo1fhERGZ4y7urpTvpmdmPf18zsRndvd/eXowpORESGXq5X7l7TT9u1QxiHiIgUSMZaPWZ2JbCU4IPd3/d4aRyQdPdBFqLJzXCo1TNQK7/4TXa172Fq+UQA5o09nOMrj4x0zMoFNVSfOjPt6w/ctpZdO5qZOju3sgpbdm+mpfMglaVjqdj6K5LJvSQSE3p38u7/GPdf+wF2lB/OrPbXsp67s7WMjpbyzJ0stURE37GDkVMdtW09R257KuPh2w8rYevhmUtv5GNq1RimjR8DwAVHXcBlx1w25GOI5CJdrZ5sa/xPAQ0EV9fe1qN9P/D80IU3+hw/752HavXs6tgbtEWY+DsammmBjIn/mFOmD+ick8dOgYNvBU+q5pI48GpqJwM8SL212zbAEeHvggxKyoLaPlkTv3vm5G9g/QzWNLkGqM2Y+Cfud3ija8gTf0tbJ7uAaePHsGX3FgAlfhl2siX+7e6+jeDK3X6ZmekirFRnXnYuZxK8Iaqvrwdg2rKBXOQ8MDtvz/57+PizZnH8WbMii2Hb1R+F/9zGET/4fsZ+a9YuBeCkRT9K36lHDaCBOuvB56B6Fl/67s/T9ll1ywoAvrTs1gGfP5Mltz8N7VC/6PRDtY5Ehptsa/yPhzdO6bVzx8zKzewcM7uL/tf/RURkmMo2418EXAfcY2ZHAnuAsQS/MH4NfM3dn4s2RBERGUrZSja0At8Gvm1mZQRr/QfdfU8hghMRkaGX841Y3L3D3RuADjO7yswGvvgqIiJFl1PiD9f0F5vZjwl2+ZwLfDfSyEREJBLZavWcT1Cg7XzgceD7wMnuru0KIiIjVLYZ/yPAUcC73f0qd/85oGqcIiIjWLbEvxB4GvhNeLet64GcrngxszvMbGd405W+r91kZm5mUwcesoiIDEa2Im3r3H2Fux9NcAetBUCZmT1sZsuznPtOgu2gvZjZ4QRLR9vzC1lERAYj641Yurn7U8BTYaXO84ArgJUZ+j9hZnP7eelfgE8DPxtQpJJVR0Nz1it4s9XzGazWzZuDK3gz9Vm0ifbJB3jy7hMBGLe1hgkvzujdqfENxs/eyyQuzDxg7aVQl/qR02sV4RW8QOW4cion9i4P0TT/HNpbD/KNB37d72lP2rmNMxr/lHnsfrzjjX28UHE0S26HV8v30WqvcWr9f0vp17Oez0ilOkQjV7abrc/t2+buSXf/tbtfZ4HZuQ5mZhcDr7v7H3Pou9zMVpvZ6qamplyHiK3KBTWUzcxcfK2joZmWddH9WY6/6CIq3vGOrP3Gba2hfHcVAO2TD7D/qNSYWt9y9u2YkNLeS+N6WH9fSvNFE8ZxeGvwuKOti5b97Sl9KidMpLxibL+nfb16ImumHZHlp+jfhINNnNAa/MKY0HUKFX54Sp+Wtk52HWjL6/zDxZbdW/jl1l8WOwzJU7YZ/5fNLEEwO18DNBHcgesvCG6kfh7BEtCObAOZWSXwDwTLPFm5+0rCdxR1dXWqBZRF9akzs87kc6nnMxiTllzOpCWXD+iYNWuXwnQ44iO96/scetewLEPdn/r+3w185j1/wWfCxw/cthZwFn9oXs4xLX7uJZg6iSWLB17HZ9UtK5gF3HLD6aQrcdWzns9IpTpEI1u2K3cvM7PjgI8QlG6YCbQAm4BfAl8Mr+7NxdHAkcAfLai4OBtYa2anuHtjnvGLiMgAZV3jd/eNwOcGO5C7rwemdT83s1eBOnffNdhzi4hI7nIu2TBQZnYPwVbQY81sR7gVVEREiiznXT0D5e5XZnl9blRji4hIepHN+EVEZHjKecZvZrOAI3oe4+5PRBGUiIhEJ6fEb2ZfApYAG4GusNkBJX4RkREm1xn/h4Fj3X1kX3UiIiI5r/FvBcqiDERERAoj1xl/C7DOzB4DDs363f2TkUQlIiKRyTXxPxh+SZ4aGxupr69Paa+traWurq5gcQyHQm59NTdvDEo39NC6aBPJlhZ2hIXc+ncAvAvqj0nbo2LmYXQlxvCrH3cAMLasharyZgBmHJjArOaJqQfNuJ4N5TNZ/PD6lJcuaX6eq5v/kD6khvB7mnISAP/41l7uaz+dJbenP81wl6kAHQzvInQqLpdj4nf3u8ysHOj+F7bF3TuiC2t0qa2t7be9sTGoVFGoxF+5oIaWLH06GpppgYIl/hnTP0h/9TpKp0yhM9vBJWVvbzVIo7xtP+1jgJIEHcky6KikqryZ5vI2Gtnbb+K/pPl56Kfe3YbymVBN5sSfg3nJV7m0HP4PlwzqPMU0oeuUtHfmaGnrZBcMy8S/ZfcWACX+XDqZ2dnAXcCrgAGHm9k12s6Zm7q6un6Te3/vAKI0HAq59TVr1pXMmpXxWr9B6S72dsQPvs8Dt62lAzjppoVvv8O44Ecpx1wdfvW1+LmXgKnwgV+kH3D7iuD7svQF3srrL+R4YNWykVukLV0BOhjeRehUXC6Q61LPbcD57r4FwMyOAe4BTooqMBERiUauu3rKupM+gLu/iHb5iIiMSLnO+Feb2feAH4bPPwKsjiYkERGJUq6J/2+BjwHd2zd/D3w7kohERCRSue7qaQO+Gn6JiMgIljHxm9m97n65ma0nqM3Ti7vPjywyERGJRLYZ/43h94uiDkRERAoj464ed+++DvHv3H1bzy/g7zIda2Z3mNlOM3uhR9uXzWyzmT1vZg+YWT+XTYqISJRy3c75vn7aPpDlmDuBRX3aHgVOCJeIXgQ+m+P4IiIyRLKt8f8twcz+KDPreUnnOODJTMe6+xNmNrdP2697PH0GuHQgwY5GPWv4FLpuT1y0bt7Mtqs/SmtFMA/ZdvXXaF20ifbJB3gyYy2g3vZN+QyvlM3h3EfWpO3Teeq7cODbj9wDwKktT3N2y297dyrtAnf4wfE5j93yYjUHNozPuX8xLXLHHeofSX1t7oQO3vuefuphRKH2UqhLvVJ3y+4tQ3YF70it+5Ntjf9HwMPA/wVW9Gjf7+67Bzn2dcCqdC+a2XJgOcCcOXMGOdTw1LOGT6Hr9sTF+Iv6/3hq3NYa9g/wXO8++EzWPgkgGT7eXjYHKklN/FhQ+CRHZVPbqaR5xCR+O/Sf3vaVlPHq3gIF0RgW2OuT+C846oIhG2Ik1/0x95TNOuk7m00DKrqfu/v2LP3nAg+5+wl92j8H1AGXeA4B1NXV+erVo/t6se5Z/7Jlxa0l0l2rZ9oNo2/D1gO3rQVg8U0LCzJeUNsHHjhx3qDO011X6KSFqXWFRpLvXBL8Ev7b+x+KfrDu6qjLMtRVGqTudw31iwpbc2sgzGyNu6fMJnNa4zezD5rZS8ArwO8IirU9nGcg1xLsEvpILklfRESGVq4f7v4TcBrworsfCZxLsEY/IGa2CPg08CF3z1YhWEREIpBr4u9w97eAhJkl3P1xgqWatMzsHuBp4Fgz22Fm1wPfJPhg+FEzW2dm3x1M8CIiMnC51urZY2bVwBPA3Wa2EziQ6QB376/I+r8NMD4RERliuc74Lya47+6ngEeAPwEfjCooERGJTtYZv5mVEOzMeS/BTrW7Io9KREQik3XG7+5dQNLMJhQgHhERiViua/zNwHoze5Qea/vu/sn0h4iIyHCUa+K/P/zqSXvwRURGoFwT/0R3/3rPBjO7MV1nGdk6GpoPXcE7UJULaqg+deYQRzR0du1oPnQFb+RjzTN2jIWzHnxuUOcpGRuUBOga5HmK5eTdzllvQVt5DUn/M99Yeh0AY/wgY/xgSv85iXKOKqlIaR+Qxjeg/QD8KPdaTAN1Be08fryxrDHY2X7BYWdx2fn/Etl4QynXXT3X9NN27RDGIcNE5YIaymbmV0Sro6GZlnVNQxzR0DnmlOlMnV2gAmEECW92al6LlR1j4Q+Tw8I9Y48lYZMA6KSUNhub0n+vd7E92T74gatqoLxq8OfJYNZO570bgoWPLd7KL9/4faTjDaVs1TmvBJYCR5rZgz1eGgcMtkibDEPVp87Me8ae77uEQjn+rFkcf9asgo23eIjOs2btl4GRWaunu17R4g/NY8ntbcB7WHXD6WnrJq26JagFecTNtxY0znxsu/qjvAOov/b7LLtzZBVXzLbU8xTQAEwFbuvRvh8Y3v/KRUSkXxkTf3inrW3A6YUJR0REopZrdc5LzOwlM9trZvvMbL+Z7Ys6OBERGXq57ur5Z+CD7r4pymBERCR6ue7qeVNJX0RkdMh1xr/azFYBPwXauhvdve9FXSIiMszlmvjHE1TnPL9Hm5N6Na+IiAxzOSV+dy/ujWBFRGTI5Lqr5xgze8zMXgifzzez/5XlmDvMbGf3MWHbZDN7NNwh9KhZeBmfiIgUTK4f7v4/4LNAB4C7Pw9ckeWYO4FFfdpWAI+5+zzgsfC5iIgUkLlnL7JpZn9w95PN7Dl3PzFsW+fuC7IcN5fgJi4nhM+3AGe7e4OZzQR+6+7HZhu/rq7OV69enf2nGcHq6+tpbGxkxowZxQ4lbx0NB/D2Lqy8ZEDHJarLKBlXHlFUI9+06d+nvPxN2tunH2qrqqpm3LhxRYwqNyuar2Br1zSOKtnJgfZOkklIJKAkCebg1ueAZAdGEizXOSkYgPU9UW8lpaWUlOb6kWaq8vIplI+Z1qutbdNmki0tJCoraWnfz8INT/L+db9NOXYKJdQwsH8TPY058jBmrPxFXsea2Rp3T6knkeufxC4zO5qwFLOZXUpQymGgprt793GNwPR0Hc1sObAcYM6cOXkMNbLU1tYWO4RBS1SXkWwe2DHe3kWyGSX+DA4cOKHX8/b2dqB5RCT+95RvgrDmWllJgg6SQD8Jv5slcA+TeS7CXx6Z+nuyi65O8k78XV0ttLeTkvhLpkw59Pi16XMBUhJ/Cw50DSrxRyHXGf9RwErgDODPwCvAVe7+apbj5tJ7xr/H3Sf2eP3P7p51nT8OM/646i7sNu2G+UWOZOSor68HYNmy0bfnYsntTwOw6obcqsSkK/bWU3fhtyV5Fn5bs3YpkLlIXncxugdOnNerfdkjwd9R/aL6vMYerEHN+N19K3CemVUBCXffn2ccb5rZzB5LPTvzPI+IiOQp1109XzSzie5+wN33m9kkM/unPMZ7kLdr+18D/CyPc4iIyCDk+gnKB9x9T/cTd/8zcEGmA8zsHuBp4Fgz22Fm1wO3Au8zs5eA88LnIiJSQLl+2lFiZmPcvQ3AzMYCYzId4O5Xpnnp3AHEJyIiQyzXxH838JiZdX9CsQy4K5qQREQkSrl+uPslM3uet2frn3f3X0UXloiIRCXnja3u/jDwcISxiIhIAegOXCIiMaM7cImIxEyuV+4+6e5nFiCefunK3dFr5+3P09HQTNnMagAqF9RQferMIkc1vOVT16m2tpa6upQLOIedJbc/zcaGfRw3c3xO/ee/1Eb1wSTNY9MvXpTv/HesvQkvr8krpnmLVlMx5QBNHUcB8PL+M9i875xefdbNLqd5TILqtmSv9tbEayRpIxFugixLJCgvzb0OEcAR5V08fNbFecU+2Fo9ugOXRKJyQQ0t4eOOhmZaQIk/i4HWdWpsbAQYEYn/4gWzBtR/56TsNXC6Kt85qEo5e16sYuIxwHiYMmY7QErin7a/q99jS308neGqeDLpdJCkPOfLp6KT64y/v0IT7u7XDX1IqTTjjwfV7YnGaK7tUwjf+quPA/Cx730zp7o96Qy0DtFQGGytHv0fIyIySuS6q2e2mT0Q3lFrp5n9xMxmRx2ciIgMvVwXm+oJCqwdFn79PGwTEZERJtfEX+Pu9e7eGX7dCeT3EbmIiBRVron/LTO7ysxKwq+rgLeiDExERKKRa+K/Dric4HaJDcClBIXaRERkhMl1V8824EMRxyIiIgWQ666eu8ys571yJ5nZHfkOamafMrMNZvaCmd1jZhX5nktERAYm16We+f3cgevEfAY0s1nAJ4G68CbsJcAV+ZxLREQGLteSDQkzmxQmfMxs8gCOTTfuWDPrACqBNwZxLhlFOhqaD13Bq7o9Q6exsfHQFbwjyXCpMdTa3MC3/urjzDn3JSom7+OR+0/L2L9ywkSqJk3u1XbRrL20tHfxg4cGVkDCy+bx0fd/bcAxZ5LrjP824Gkz+7yZfR54iqBi54C5++vAV4DtBB8U73X3X/ftZ2bLzWy1ma1uamrKZygZYSoX1Bwq1tbR0EzLOv29D4Xa2toBFXQbLhobG1m/fn2xw+Dok86kojqYgOx+eQatuzMXkGtvPUjL3j0p7VOrx1BZPpiqQUMnp1o9AGZ2HNBdmeg/3H1jXgOaTQJ+AiwB9gA/Bu5z9x+mO0a1euJHdXtkONYYeuC2tQAsvmlh2j6rblkBwJKbby1ITJkMtjonYaLPK9n3cR7wirs3hYHdD5wBpE38IiIydIpRH3Q7cJqZVZqZEdzHVzd4EREpkIInfnd/FrgPWAusD2NYWeg4RETiajA7c/Lm7jcDNxdjbBGRuCv+rWBERKSglPhFRGJGiV9EJGaU+EVEYkaJX0QkZpT4RURipijbOUVy0bNgmwzcaChyN9yKy+1qaaajrYuXP/9k2j7trSUku9r5wmc/NyRjVpZW8KnP/+8hOVc3JX4ZlioX1NBS7CBGsI6GZlpgRCf+2traYoeQonJcOS20Z+xTUlpVoGjyp8Qvw1L1qTNHdNIqttHwTqmurm5YlGQejbTGLyISM0r8IiIxo8QvIhIzSvwiIjGjxC8iEjNK/CIiMaPELyISM0VJ/GY20czuM7PNZrbJzE4vRhwiInFUrAu4vg484u6Xmlk5UFmkOEREYqfgid/MJgB/CVwL4O7tkOUaaBEZMNU6GrjRUN8oF8VY6jkSaALqzew5M/uemaUUtzCz5Wa22sxWNzU1FT5KkRGsckENZTOrix3GiNLR0EzLunjkGnP3wg5oVgc8A5zp7s+a2deBfe6etvxcXV2dr169umAxikj8dL87mnbD/CJHMnTMbI27pxQ8KsaMfweww92fDZ/fBywsQhwiIrFU8MTv7o3Aa2Z2bNh0LrCx0HGIiMRVsXb1fAK4O9zRsxVYVqQ4RERipyiJ393XASq0LSJSBLpyV0QkZpT4RURiRolfRCRmlPhFRGJGiV9EJGaU+EVEYqZY+/hFRIadnoXtRnPBNiV+ERGCRN8SPu5oaKYFlPhFREaz6lNnHkr0o72ctdb4RURiRolfRCRmlPhFRGJGiV9EJGaU+EVEYkaJX0QkZpT4RURipmiJ38xKzOw5M3uoWDGIiMRRMWf8NwKbiji+iEgsFeXKXTObDVwIfAH478WIQUQkk551e4qp/LAqJn7w6CE9Z7FKNnwN+DQwLl0HM1sOLAeYM2dOgcISEVeP5wgAAAULSURBVOldt2c0KnjiN7OLgJ3uvsbMzk7Xz91XAisB6urqvEDhiYj0qtszGhVjjf9M4ENm9irw78A5ZvbDIsQhIhJLBU/87v5Zd5/t7nOBK4D/cPerCh2HiEhcaR+/iEjMFLUev7v/FvhtMWMQEYkbzfhFRGJGiV9EJGaU+EVEYkaJX0QkZsx9+F8bZWZNwLY8D58K7BrCcEYC/czxoJ85HgbzMx/h7jV9G0dE4h8MM1vt7nXFjqOQ9DPHg37meIjiZ9ZSj4hIzCjxi4jETBwS/8piB1AE+pnjQT9zPAz5zzzq1/hFRKS3OMz4RUSkByV+EZGYGdWJ38wWmdkWM3vZzFYUO56omdnhZva4mW00sw1mdmOxYyoEMysxs+fM7KFix1IIZjbRzO4zs81mtsnMTi92TFEzs0+F/0+/YGb3mFlFsWMaamZ2h5ntNLMXerRNNrNHzeyl8PukoRhr1CZ+MysBvgV8ADgOuNLMjituVJHrBG5y9+OA04CPxeBnBrgR2FTsIAro68Aj7v4O4F2M8p/dzGYBnwTq3P0EoITgXh6jzZ3Aoj5tK4DH3H0e8Fj4fNBGbeIHTgFedvet7t5OcLevi4scU6TcvcHd14aP9xMkhFnFjSpaZjYbuBD4XrFjKQQzmwD8JfBvAO7e7u57ihtVQZQCY82sFKgE3ihyPEPO3Z8Advdpvhi4K3x8F/DhoRhrNCf+WcBrPZ7vYJQnwZ7MbC5wIvBscSOJ3NeATwPJYgdSIEcCTUB9uLz1PTOrKnZQUXL314GvANuBBmCvu/+6uFEVzHR3bwgfNwLTh+Kkoznxx5aZVQM/Af7e3fcVO56omNlFwE53X1PsWAqoFFgIfMfdTwQOMERv/4ercF37YoJfeocBVWYWu9u1erD3fkj234/mxP86cHiP57PDtlHNzMoIkv7d7n5/seOJ2JnAh8zsVYKlvHPM7IfFDSlyO4Ad7t79Tu4+gl8Eo9l5wCvu3uTuHcD9wBlFjqlQ3jSzmQDh951DcdLRnPj/AMwzsyPNrJzgw6AHixxTpMzMCNZ+N7n7V4sdT9Tc/bPuPtvd5xL8/f6Hu4/qmaC7NwKvmdmxYdO5wMYihlQI24HTzKwy/H/8XEb5B9o9PAhcEz6+BvjZUJy0qPfcjZK7d5rZx4FfEewCuMPdNxQ5rKidCVwNrDezdWHbP7j7L4sYkwy9TwB3hxOarcCyIscTKXd/1szuA9YS7Fx7jlFYusHM7gHOBqaa2Q7gZuBW4F4zu56gNP3lQzKWSjaIiMTLaF7qERGRfijxi4jEjBK/iEjMKPGLiMSMEr+ISMwo8Yv0EVa//Lvw8WHhVkKRUUPbOUX6COscPRRWghQZdUbtBVwig3ArcHR4EdxLwDvd/QQzu5agOmIVMI+gcFg5wUVzbcAF7r7bzI4mKAleA7QAf+3umwv/Y4j0T0s9IqlWAH9y9wXA/+zz2gnAJcDJwBeAlrBY2tPAR8M+K4FPuPtJwP8Avl2QqEVypBm/yMA8Ht7rYL+Z7QV+HravB+aHlVHPAH4clJUBYEzhwxRJT4lfZGDaejxO9nieJPj3lAD2hO8WRIYlLfWIpNoPjMvnwPD+B6+Y2WUQVEw1s3cNZXAig6XEL9KHu78FPBne9PrLeZziI8D1ZvZHYAOj/JafMvJoO6eISMxoxi8iEjNK/CIiMaPELyISM0r8IiIxo8QvIhIzSvwiIjGjxC8iEjP/Hx4yZTeUKwPDAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "np.random.seed(3)\n", + "\n", + "# Load a forward model\n", + "model = toy.stochastic.DegradationModel()\n", + "\n", + "# Create some toy data\n", + "real_parameters = model.suggested_parameters()\n", + "times = np.linspace(0, 10, 100)\n", + "\n", + "for i in range(10):\n", + " values = model.simulate(real_parameters, times)\n", + "\n", + " # Create an object with links to the model and time series\n", + " problem = pints.SingleOutputProblem(model, times, values)\n", + "\n", + " # Create a uniform prior parameter\n", + " log_prior = pints.UniformLogPrior([0.0], [0.3])\n", + "\n", + " # Set the error measure to be used to compare simulated to observed data\n", + " error_measure = pints.RootMeanSquaredError(problem)\n", + "\n", + " plt.step(times, values)\n", + "\n", + "\n", + "plt.xlabel('time')\n", + "plt.ylabel('concentration (A(t))')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fit using Rejection ABC\n", + "\n", + "The rejection ABC method can be used to perform parameter inference for stochastic models, where the likelihood is intractable. In ABC methods, typically, a distance metric comparing the observed data and the simulated is used. Here, we use the root mean square error (RMSE), and we accept a parameter value if the $RMSE<1$." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running...\n", + "Using Rejection ABC\n", + "Running in sequential mode.\n", + "Iter. Eval. Acceptance rate Time m:s\n", + "1 198 0.00505050505 0:00.2\n", + "2 213 0.00938967136 0:00.2\n", + "3 271 0.0110701107 0:00.2\n", + "20 1081 0.0185013876 0:00.8\n", + "40 2389 0.0167434073 0:01.8\n", + "60 3734 0.0160685592 0:02.8\n", + "80 4774 0.0167574361 0:03.5\n", + "100 6078 0.0164527805 0:04.5\n", + "120 7352 0.0163220892 0:05.4\n", + "140 8780 0.0159453303 0:06.5\n", + "160 10169 0.0157340938 0:07.5\n", + "180 11237 0.0160185103 0:08.3\n", + "200 12453 0.0160603871 0:09.2\n", + "220 14073 0.015632772 0:10.4\n", + "240 15457 0.0155269457 0:11.4\n", + "260 16782 0.0154927899 0:12.4\n", + "280 18094 0.015474743 0:13.4\n", + "300 19290 0.0155520995 0:14.3\n", + "320 20742 0.0154276348 0:15.4\n", + "340 21715 0.0156573797 0:16.1\n", + "360 23213 0.0155085512 0:17.2\n", + "380 24642 0.0154208262 0:18.2\n", + "400 25951 0.0154136642 0:19.2\n", + "420 27092 0.0155027314 0:20.0\n", + "440 28605 0.0153819262 0:21.1\n", + "460 29761 0.0154564699 0:22.0\n", + "480 30963 0.0155023738 0:22.9\n", + "500 32579 0.0153473096 0:24.1\n", + "520 33669 0.0154444741 0:24.9\n", + "540 34618 0.0155988214 0:25.6\n", + "560 35662 0.0157029892 0:26.3\n", + "580 37048 0.015655366 0:27.3\n", + "600 38963 0.0153992249 0:28.7\n", + "620 40448 0.0153283228 0:29.8\n", + "640 42540 0.0150446638 0:31.4\n", + "660 43768 0.0150795101 0:32.3\n", + "680 45169 0.0150545728 0:33.3\n", + "700 46368 0.0150966184 0:34.2\n", + "720 47499 0.0151582139 0:35.0\n", + "740 48691 0.0151978805 0:35.9\n", + "760 49616 0.0153176395 0:36.6\n", + "780 50795 0.0153558421 0:37.4\n", + "800 51940 0.0154023874 0:38.3\n", + "820 52849 0.0155159038 0:39.0\n", + "840 53995 0.015556996 0:39.8\n", + "860 54990 0.0156392071 0:40.5\n", + "880 55919 0.0157370482 0:41.2\n", + "900 57460 0.01566307 0:42.4\n", + "920 58346 0.0157680047 0:43.0\n", + "940 60000 0.0156666667 0:44.2\n", + "960 60898 0.0157640645 0:44.9\n", + "980 62112 0.0157779495 0:45.8\n", + "1000 63098 0.0158483629 0:46.5\n", + "Halting: target number of samples (1000) reached.\n", + "Done\n" + ] + } + ], + "source": [ + "np.random.seed(1)\n", + "abc = pints.ABCController(error_measure, log_prior)\n", + "\n", + "# set threshold\n", + "abc.sampler().set_threshold(1)\n", + "\n", + "# set target number of samples\n", + "abc.set_n_samples(1000)\n", + "\n", + "# log to screen\n", + "abc.set_log_to_screen(True)\n", + "\n", + "print('Running...')\n", + "samples = abc.run()\n", + "print('Done')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now plot the ABC posterior samples versus the actual value that was used to generate the data. This shows that, in this case, the parameter could be identified given the data." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAcQElEQVR4nO3de3RU9b338fcXQfCCF0hKkYDBCgpyCRJUlNMFRSuyELQiYJWCpUVBXVp1rUetXfJ4sMdTK/rYKoVTPaF6jPJQUeSxz4Nc1Fq5CBSRS5HUQglSLqFyRAQFvs8fsxM3Y0Im2ZnsYft5rTVr9vz2ZT7ZM/nml9/es8fcHRERSZYmcQcQEZGGp+IuIpJAKu4iIgmk4i4ikkAq7iIiCdQ07gAAeXl5XlhYGHcMEZFjyooVK3a5e35183KiuBcWFrJ8+fK4Y4iIHFPMbHNN8zQsIyKSQCruIiIJpOIuIpJAOTHmLiK564svvqC8vJz9+/fHHeVrq0WLFhQUFNCsWbOM11FxF5GjKi8vp2XLlhQWFmJmccf52nF3KioqKC8vp2PHjhmvp2EZETmq/fv307p1axX2mJgZrVu3rvN/TrUWdzNrYWbLzOw9M1trZv8zaO9oZkvNrMzMXjSz44P25sHjsmB+YT1+HhHJISrs8arP/s+k534A+I679wSKgEFmdhHw78Bj7n428E9gXLD8OOCfQftjwXIiItKIai3unrI3eNgsuDnwHWBW0D4DuCqYHhY8Jpg/0PRnXyQxzBr2lqmHHnqI8847jx49elBUVMTSpUuz9jP279//mP9gZUYHVM3sOGAFcDbwJPBX4GN3PxgsUg60C6bbAVsA3P2gme0BWgO70rY5HhgP0KFDh2g/xddMnH8qq/1ul3feSd1ffHGjZpGvj8WLFzN37lxWrlxJ8+bN2bVrF59//nncsXJaRgdU3f2QuxcBBcAFwLlRn9jdp7t7sbsX5+dXe2kEOVZcfLEKu2TVtm3byMvLo3nz5gDk5eVxxhln8OCDD9KnTx+6devG+PHjqfxmuf79+/OTn/yE4uJiunTpwrvvvsv3vvc9OnXqxP333w/Apk2bOPfcc7n++uvp0qULw4cPZ9++fV957nnz5tG3b1/OP/98rr32WvbuTQ1k3HPPPXTt2pUePXpw9913N9KeyFydzpZx94+BRUBf4DQzq+z5FwBbg+mtQHuAYP6pQEWDpJXc9M47X/beRbLgu9/9Llu2bKFz585MnDiRN998E4Bbb72Vd999lzVr1vDZZ58xd+7cqnWOP/54li9fzs0338ywYcN48sknWbNmDSUlJVRUpErShg0bmDhxIuvXr+eUU07hqaeeOuJ5d+3axeTJk5k/fz4rV66kuLiYKVOmUFFRwezZs1m7di2rV6+u+oORSzI5WybfzE4Lpk8ALgPWkyryw4PFxgCvBNNzgscE8xe6vqg12e67L3UTyZKTTz6ZFStWMH36dPLz8xk5ciQlJSUsWrSICy+8kO7du7Nw4ULWrl1btc7QoUMB6N69O+eddx5t27alefPmnHXWWWzZsgWA9u3bc8kllwBwww038Pbbbx/xvEuWLGHdunVccsklFBUVMWPGDDZv3sypp55KixYtGDduHC+99BInnnhiI+2JzGUy5t4WmBGMuzcBZrr7XDNbB7xgZpOBPwNPB8s/DTxrZmXAbmBUFnKLyNfMcccdR//+/enfvz/du3dn2rRprF69muXLl9O+fXsmTZp0xLnglUM4TZo0qZqufHzwYOpwYfq5HumP3Z3LLruM0tLSr+RZtmwZCxYsYNasWfz6179m4cKFDfazNoRMzpZZ7e693L2Hu3dz9weD9g/d/QJ3P9vdr3X3A0H7/uDx2cH8D7P9Q4hIsm3YsIGNGzdWPV61ahXnnHMOkBp/37t3L7Nmzapp9Rr9/e9/Z/HixQA8//zz9OvX74j5F110EX/6058oKysD4NNPP+WDDz5g79697Nmzh8GDB/PYY4/x3nvv1fdHyxpdfkBE6iSOQda9e/dy22238fHHH9O0aVPOPvtspk+fzmmnnUa3bt345je/SZ8+feq83XPOOYcnn3ySH/7wh3Tt2pUJEyYcMT8/P5+SkhKuu+46Dhw4AMDkyZNp2bIlw4YNY//+/bg7U6ZMaZCfsyFZLgyHFxcX+7F+TmljyrlTIfv3T92/8UYjJpHGsn79erp06RJ3jAa3adMmhgwZwpo1a+KOkpHqXgczW+HuxdUtr567RPf443EnEJE0Ku4SXVFR3AlE6qywsPCY6bXXh64KKdHNn5+6iUjOUM9dops8OXV/6aXx5hCRKuq5i4gkkIq7iEgCqbiLyDHh5Zdfxsz4y1/+Uuuyjz/+eLUXActUSUkJt956a73Xb+jt1IeKu4gcE0pLS+nXr1+1lwJIF7W4J4GKu0Q3bVrqJpIle/fu5e233+bpp5/mhRdeqGo/dOgQd999N926daNHjx786le/4oknnuCjjz5iwIABDBgwAEhdeKzSrFmzGDt2LACvvvoqF154Ib169eLSSy9l+/btNWY4fPgwhYWFfPzxx1VtnTp1Yvv27RltZ+zYsUdcIiGc6ZFHHqFPnz706NGDBx54oO47qBo6W0aiC67xIV8TlZ9IDhsxAiZOhH37YPDgr84fOzZ127ULhg8/cl4Gn2x+5ZVXGDRoEJ07d6Z169asWLGC3r17M336dDZt2sSqVato2rQpu3fvplWrVkyZMoVFixaRl5d31O3269ePJUuWYGb89re/5Re/+AWPPvpotcs2adKEYcOGMXv2bG688UaWLl3KmWeeSZs2beq0nXTz5s1j48aNLFu2DHdn6NChvPXWW3z729/OaP2aqLhLdK++mrq/8sp4c0hilZaWcvvttwMwatQoSktL6d27N/Pnz+fmm2+madNUKWvVqlWdtlteXs7IkSPZtm0bn3/+OR07djzq8iNHjuTBBx/kxhtv5IUXXmDkyJH12k7YvHnzmDdvHr169QJS/6Vs3LhRxV1yQGUPRcX96+FoPe0TTzz6/Ly8Ol+DaPfu3SxcuJD3338fM+PQoUOYGY888kjG2whfyjd8WeDbbruNO++8k6FDh/LGG28wadKko26nb9++lJWVsXPnTl5++eWqL+nIZDtNmzbl8OHDQGqIp/JrAt2de++9l5tuuinjnycTGnMXkZw2a9YsRo8ezebNm9m0aRNbtmyhY8eO/PGPf+Syyy5j2rRpVddn3717NwAtW7bkk08+qdpGmzZtWL9+PYcPH2b27NlV7Xv27KFdu9TXP8+YMaPWLGbG1VdfzZ133kmXLl1o3bp1xtspLCxkxYoVAMyZM4cvvvgCgMsvv5xnnnmm6uv7tm7dyo4dOzLfQTVQcReRnFZaWsrVV199RNs111xDaWkpP/rRj+jQoQM9evSgZ8+ePP/88wCMHz+eQYMGVR1QffjhhxkyZAgXX3wxbdu2rdrOpEmTuPbaa+ndu3et4/OVRo4cyXPPPVc1JJPpdn784x/z5ptv0rNnTxYvXsxJJ50EpL5C8Pvf/z59+/ale/fuDB8+/Ig/TPWlS/4eg3TJX2lMSb3k77Gmrpf8Vc9dRCSBdEBVonv22bgTiEgaFXeJrn37uBNIlrn7V748WhpPfYbPNSwj0b34YuomidSiRQsqKirqVWAkOnenoqKCFi1a1Gk99dwluqlTU/ehswckOQoKCigvL2fnzp1xR/naatGiBQUFBXVaR8VdRI6qWbNmdfrEpeQGDcuIiCSQiruISAKpuIuIJFCtxd3M2pvZIjNbZ2Zrzez2oH2SmW01s1XBbXBonXvNrMzMNpjZ5dn8ASQHzJqVuolIzsjkgOpB4C53X2lmLYEVZvZ6MO8xd/9leGEz6wqMAs4DzgDmm1lndz/UkMElh2R4TQ4RaTy19tzdfZu7rwymPwHWA+2Ossow4AV3P+DufwPKgAsaIqzkqJKS1E1EckadxtzNrBDoBSwNmm41s9Vm9oyZnR60tQO2hFYrp5o/BmY23syWm9lynT97jFNxF8k5GRd3MzsZ+D1wh7v/NzAV+BZQBGwDMvtOqYC7T3f3Yncvzs/Pr8uqIiJSi4yKu5k1I1XY/8vdXwJw9+3ufsjdDwP/wZdDL1uB8MVGCoI2ERFpJJmcLWPA08B6d58Sam8bWuxqYE0wPQcYZWbNzawj0AlY1nCRRUSkNpmcLXMJMBp438xWBW33AdeZWRHgwCbgJgB3X2tmM4F1pM60uUVnyoiINC59E9MxKOe+iWnfvtT9iSc2ahaRr7ujfROTLhwm0amoi+QcXX5AonvqqdRNRHKGirtEN3Nm6iYiOUPFXUQkgVTcRUQSSAdUI9D3BYtIrlJxlzqp7g/aouB+QBb/2OXAGbsixxQVd4lsAG/EHUFE0mjMXUQkgVTcJbK7+CV38cvaFxSRRqPiLpENYS5DmBt3DBEJUXEXEUkgFXcRkQRScRcRSSCdCimRfcYJcUcQkTQq7hLZYP4QdwQRSaNhGRGRBFJxl8ju51+5n3+NO4aIhKi4S2QDWcBAFsQdQ0RCVNxFRBJIxV1EJIFU3EVEEkinQkpkFbSOO4KIpFFxl8iG8/u4I4hIGg3LiIgkUK3F3czam9kiM1tnZmvN7PagvZWZvW5mG4P704N2M7MnzKzMzFab2fnZ/iEkXj/nXn7OvXHHEJGQTHruB4G73L0rcBFwi5l1Be4BFrh7J2BB8BjgCqBTcBsPTG3w1JJT+rKYviyOO4aIhNRa3N19m7uvDKY/AdYD7YBhwIxgsRnAVcH0MOB3nrIEOM3M2jZ4chERqVGdxtzNrBDoBSwF2rj7tmDWP4A2wXQ7YEtotfKgLX1b481suZkt37lzZx1ji4jI0WRc3M3sZOD3wB3u/t/hee7ugNflid19ursXu3txfn5+XVYVEZFaZHQqpJk1I1XY/8vdXwqat5tZW3ffFgy77AjatwLtQ6sXBG2SUOUUxB1BRNJkcraMAU8D6919SmjWHGBMMD0GeCXU/oPgrJmLgD2h4RtJoNE8x2ieizuGiIRk0nO/BBgNvG9mq4K2+4CHgZlmNg7YDIwI5r0GDAbKgH3AjQ2aWEREalVrcXf3twGrYfbAapZ34JaIueQY8hh3APATHo85iYhU0uUHJLIiVtW+kIg0Kl1+QEQkgVTcRUQSSMVdRCSBNOYukX1A57gjiEgaFXeJ7Camxx1BRNJoWEZEJIFU3CWyaYxnGuPjjiEiIRqWkcg680HcEUQkjXruIiIJpOIuIpJAKu4iIgmkMXeJbBVFcUcQkTQq7hKZrgYpkns0LCMikkAq7hLZs9zAs9wQdwwRCdGwjERWQHncEUQkjXruIiIJpOIuIpJAKu4iIgmkMXeJbDF9444gImlU3CWy+/i3uCOISBoNy4iIJJCKu0Q2i2uYxTVxxxCREA3LSGStqYg7goikqbXnbmbPmNkOM1sTaptkZlvNbFVwGxyad6+ZlZnZBjO7PFvBRUSkZpkMy5QAg6ppf8zdi4LbawBm1hUYBZwXrPOUmR3XUGFFRCQztRZ3d38L2J3h9oYBL7j7AXf/G1AGXBAhn4iI1EOUA6q3mtnqYNjm9KCtHbAltEx50CYJtoCBLGBg3DFEJKS+xX0q8C2gCNgGPFrXDZjZeDNbbmbLd+7cWc8Ykgsm8zMm87O4Y4hISL2Ku7tvd/dD7n4Y+A++HHrZCrQPLVoQtFW3jenuXuzuxfn5+fWJISIiNahXcTeztqGHVwOVZ9LMAUaZWXMz6wh0ApZFiyi57jWu4DWuiDuGiITUep67mZUC/YE8MysHHgD6m1kR4MAm4CYAd19rZjOBdcBB4BZ3P5Sd6JIrTuCzuCOISJpai7u7X1dN89NHWf4h4KEooUREJBpdfkBEJIFU3EVEEkjXlpHI5jIk7ggikkbFXSJ7lLvjjiAiaTQsIyKSQCruEtki+rOI/nHHEJEQFXcRkQRScRcRSSAVdxGRBFJxFxFJIJ0KKZHNZETcEUQkjYq7RDaViXFHEJE0GpaRyE5gHyewL+4YIhKinrtE9hqDARjAG/EGEZEq6rmLiCSQiruISAKpuIuIJJDG3OWYYBbfc7vH99wi9aXiLpGVMDbuCCKSRsVdIpuh4i6SczTmLpG1Zhet2RV3DBEJUc9dIpvFcEDnuYvkEvXcRUQSSMVdRCSBVNxFRBJIxV1EJIFqLe5m9oyZ7TCzNaG2Vmb2upltDO5PD9rNzJ4wszIzW21m52czvOSGqUxgKhPijiEiIZn03EuAQWlt9wAL3L0TsCB4DHAF0Cm4jQemNkxMyWUzGclMRsYdQ0RCai3u7v4WsDuteRgwI5ieAVwVav+dpywBTjOztg0VVnJTAVsoYEvcMUQkpL7nubdx923B9D+ANsF0Ozjit7w8aNtGGjMbT6p3T4cOHeoZQ3LBs4wGdJ67SC6JfEDV3R2o86WV3H26uxe7e3F+fn7UGCIiElLf4r69crgluN8RtG8F2oeWKwjaRESkEdW3uM8BxgTTY4BXQu0/CM6auQjYExq+ERGRRlLrmLuZlQL9gTwzKwceAB4GZprZOGAzMCJY/DVgMFAG7ANuzEJmERGpRa3F3d2vq2HWwGqWdeCWqKHk2PIod8UdQUTS6KqQEtlcrow7goik0eUHJLLObKAzG+KOISIh6rlLZNO4CdB57iK5RD13EZEEUnEXEUkgFXcRkQRScRcRSSAdUJXIJnN/3BFEJI2Ku0S2gEvjjiAiaTQsI5H1ZBU9WRV3DBEJUc9dInucOwCd5y6SS9RzFxFJIBV3EZEEUnEXEUkgFXcRkQTSAVWJ7D5+HncEEUmj4i6RLebiuCOISBoNy0hkfXmHvrwTdwwRCVHPXSL7OfcBOs9dJJeo5y4ikkAq7iIiCaTiLiKSQMf8mLtZ3AlERHLPMV/cJX538HjcEUQkjYq7RPYeRXFHEJE0kYq7mW0CPgEOAQfdvdjMWgEvAoXAJmCEu/8zWkzJZQOZD+hLO0RySUMcUB3g7kXuXhw8vgdY4O6dgAXBY0mw+5nM/UyOO4aIhGTjbJlhwIxgegZwVRaeQ0REjiJqcXdgnpmtMLPxQVsbd98WTP8DaFPdimY23syWm9nynTt3RowhIiJhUQ+o9nP3rWb2DeB1M/tLeKa7u5l5dSu6+3RgOkBxcXG1y4iISP1E6rm7+9bgfgcwG7gA2G5mbQGC+x1RQ4qISN3Uu+duZicBTdz9k2D6u8CDwBxgDPBwcP9KQwSV3HUT0+KOICJpogzLtAFmW+ojok2B5939/5rZu8BMMxsHbAZGRI8puewDzok7goikqXdxd/cPgZ7VtFcAA6OEkmPLEF4FYC5XxpxERCrpE6oS2V08Cqi4i+QSXRVSRCSBVNxFRBJIxV1EJIFU3EVEEkgHVCWy0Twbd4SsiusLYVyf25YIVNwlsnLaxx1BRNJoWEYiG8GLjODFuGOISIh67hLZBKYCMJORMScRkUrquYuIJJCKu4hIAqm4i4gkkIq7iEgC6YCqRDacWXFHEJE0Ku4SWQV5cUcQkTQalpHIxlDCGErijiEiIeq5S2Rjg8I+g7Gx5hCRL6nnLiKSQCruIiIJpGEZkRylq1FKFOq5i4gkkHruEtlgXos7goikUXGXyD7jxLgjiEgaDctIZBN4igk8FXcMEQlRz10iG8FMAKYyMeYk0hDiOpALOpjbkNRzFxFJoKwVdzMbZGYbzKzMzO7J1vOIiMhXZaW4m9lxwJPAFUBX4Doz65qN5xIRka/K1pj7BUCZu38IYGYvAMOAdVl6PhGRekvicYZsFfd2wJbQ43LgwvACZjYeGB88PGBma7KUJYo8YFfcIaqRU7kGpO7ywHImU0hO7auQXMwVe6YaimzsuWrQILki/mE5s6YZsZ0t4+7TgekAZrbc3YvjylIT5cpcLmYC5aqLXMwEylVf2TqguhVoH3pcELSJiEgjyFZxfxfoZGYdzex4YBQwJ0vPJSIiabIyLOPuB83sVuD/AccBz7j72qOsMj0bORqAcmUuFzOBctVFLmYC5aoXc30kTEQkcfQJVRGRBFJxFxFJoGx9QvWolx4ws+Zm9mIwf6mZFQbtl5nZCjN7P7j/TmidN4Jtrgpu32ikTIVm9lnoeX8TWqd3kLXMzJ4wq/sZqxFyXR/KtMrMDptZUTAv0r7KMNe3zWylmR00s+Fp88aY2cbgNibUHml/1TeTmRWZ2WIzW2tmq81sZGheiZn9LbSviuqSKUquYN6h0HPPCbV3DF7vsuD1P76xcpnZgLT31n4zuyqYF2l/ZZDpTjNbF7xOC8zszNC8rLyvouTK9nsrEndv0BupA6h/Bc4CjgfeA7qmLTMR+E0wPQp4MZjuBZwRTHcDtobWeQMojiFTIbCmhu0uAy4CDPgDcEVj5Upbpjvw14bYV3XIVQj0AH4HDA+1twI+DO5PD6ZPj7q/ImbqDHQKps8AtgGnBY9Lwss25r4K5u2tYbszgVHB9G+ACY2ZK+313A2cGHV/ZZhpQOi5JvDl72FW3lcNkCtr762ot2z03KsuPeDunwOVlx4IGwbMCKZnAQPNzNz9z+7+UdC+FjjBzJrHmammDZpZW+AUd1/iqVfyd8BVMeW6Lli3odSay903uftq4HDaupcDr7v7bnf/J/A6MKgB9le9M7n7B+6+MZj+CNgB5NfhubOSqybB6/sdUq83pF7/Bn9vZZhrOPAHd99Xx+evb6ZFoedaQuozMpC991WkXFl+b0WSjeJe3aUH2tW0jLsfBPYArdOWuQZY6e4HQm3/Gfx787M6/usVNVNHM/uzmb1pZv8SWr68lm1mO1elkUBpWlt991Wmueq6btT9FSVTFTO7gFTv7K+h5oeCf6kfq0dnImquFma23MyWVA59kHp9Pw5e7/pssyFyVRrFV99b9d1fdc00jlRP/GjrNtbvYU25qmThvRVJTh5QNbPzgH8Hbgo1X+/u3YF/CW6jGynONqCDu/cC7gSeN7NTGum5a2VmFwL73D18bZ649lVOC3p5zwI3untlb/Ve4FygD6l/+f9HI8c601MfYf8+8LiZfauRn79Gwf7qTurzKpUaZX+Z2Q1AMfBINrZfXzXlysX3VjaKeyaXHqhaxsyaAqcCFcHjAmA28AN3r/oL6O5bg/tPgOdJ/SuV9UzufsDdK4LnXkHqr3LnYPmC0Pr1ucRCpH0V+ErPKuK+yjRXXdeNur8iXdIi+IP8f4CfuvuSynZ33+YpB4D/pHH3Vfi1+pDUsZJepF7f04LXu87bbIhcgRHAbHf/IpQ3yv7KKJOZXQr8FBga+s89W++rqLmy+d6KpqEH8Ul96vVDoCNfHpw4L22ZWzjyIOHMYPq0YPnvVbPNvGC6GamxyJsbKVM+cFwwfRapF72VV38gZ3Bj7avgcZMgz1kNta8yzRVatoSvHlD9G6mDXqcH05H3V8RMxwMLgDuqWbZtcG/A48DDjbivTgeaB9N5wEaCA3nA/+bIA6oTGytXqH0JMKCh9leG7/depDpQndLas/K+aoBcWXtvRb1lZ6MwGPgg2Bk/DdoeJPUXD6BF8OYtC16Ys4L2+4FPgVWh2zeAk4AVwGpSB1r/F0HBbYRM1wTPuQpYCVwZ2mYxsCbY5q8JPvHbGLmCef2BJWnbi7yvMszVh9TY5KekepprQ+v+MMhbRurf1AbZX/XNBNwAfJH2vioK5i0E3g9yPQec3Fj7Crg4eO73gvtxoW2eFbzeZcHr37yRX8NCUh2HJmnbjLS/Msg0H9geep3mZPt9FSVXtt9bUW66/ICISALl5AFVERGJRsVdRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQS6P8Di4wp7RG+U0gAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(samples[:,0], color=\"blue\", label=\"Samples\")\n", + "plt.vlines(x=model.suggested_parameters(), linestyles='dashed', ymin=0, ymax=300, label=\"Actual value\", color=\"red\")\n", + "plt.legend()\n", + "plt.show()" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "62b8c3045b77e73a8aab814fbf01ae024ab075fc3f7014742f3a4c5a8ac43e7b" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pints/__init__.py b/pints/__init__.py index ec03b8229..51b3e3506 100644 --- a/pints/__init__.py +++ b/pints/__init__.py @@ -236,11 +236,20 @@ def version(formatted=False): from ._nested._ellipsoid import NestedEllipsoidSampler +# +# ABC +# +from ._abc import ABCSampler +from ._abc import ABCController +from ._abc._abc_rejection import RejectionABC + + # # Sampling initialising # from ._sample_initial_points import sample_initial_points + # # Transformations # diff --git a/pints/_abc/__init__.py b/pints/_abc/__init__.py new file mode 100644 index 000000000..6e2b74419 --- /dev/null +++ b/pints/_abc/__init__.py @@ -0,0 +1,364 @@ +# +# Sub-module containing ABC inference routines +# +# This file is part of PINTS (https://github.com/pints-team/pints/) which is +# released under the BSD 3-clause license. See accompanying LICENSE.md for +# copyright notice and full license details. +# +import pints +import numpy as np + + +class ABCSampler(pints.Loggable, pints.TunableMethod): + """ + Abstract base class for ABC methods. + All ABC samplers implement the :class:`pints.Loggable` and + :class:`pints.TunableMethod` interfaces. + """ + + def name(self): + """ + Returns this method's full name. + """ + raise NotImplementedError + + def ask(self): + """ + Returns a parameter vector sampled from the LogPrior. + """ + raise NotImplementedError + + def tell(self, x): + """ + Performs an iteration of the ABC algorithm, using the + parameters specified by ask. + Expects to receive x as a sequence of length at least 1. + Returns the accepted parameter values. + """ + raise NotImplementedError + + +class ABCController(object): + """ + Samples from a :class:`pints.LogPrior`. + + Properties related to the number of iterations, parallelisation, + threshold, and number of parameters to sample can be set directly on the + ``ABCController`` object. Afterwards the ABC routine can be run. + + Parameters + ---------- + error_measure + An error measure to evaluate on a problem, given a forward model, + simulated and observed data, and times + log_prior + A :class:`LogPrior` function from which parameter values are sampled + method + The class of :class:`ABCSampler` to use. If no method is specified, + :class:`RejectionABC` is used. + + Example + ------- + :: + abc = pints.ABCController(error_measure, log_prior) + abc.set_max_iterations(1000) + posterior_estimate = abc.run() + + """ + + def __init__(self, error_measure, log_prior, method=None): + + # Store function + if not isinstance(log_prior, pints.LogPrior): + raise ValueError('Given function must extend pints.LogPrior.') + self._log_prior = log_prior + + # Check error_measure + if not isinstance(error_measure, pints.ErrorMeasure): + raise ValueError('Given error_measure must extend ' + 'pints.ErrorMeasure') + self._error_measure = error_measure + + # Check if number of parameters from prior matches that of error + # measure + if self._log_prior.n_parameters() != \ + self._error_measure.n_parameters(): + raise ValueError('Number of parameters in prior must match number ' + 'of parameters in error measure.') + + # Get number of parameters + self._n_parameters = self._log_prior.n_parameters() + + # Set rejection ABC as default method + if method is None: + method = pints.RejectionABC + else: + try: + ok = issubclass(method, ABCSampler) + except TypeError: # Not a class + ok = False + if not ok: + raise ValueError('Given method must extend ABCSampler.') + + # Initialisation + + # Parallelisation + self._parallel = False + self._n_workers = 1 + + # Maximum number of iterations as a stopping criterion + self._max_iterations = 10000 + + # Maximum number of target samples to obtain + # in the estimated posterior + self._n_samples = 500 + + # The sampler object uses the prior distribution + self._sampler = method(log_prior) + + # Logging + self._log_to_screen = True + self._log_filename = None + self._log_csv = False + self.set_log_interval() + + def set_log_interval(self, iters=20, warm_up=3): + """ + Changes the frequency with which messages are logged. + + Parameters + ---------- + iters + A log message will be shown every ``iters`` iterations. + warm_up + A log message will be shown every iteration, for the first + ``warm_up`` iterations. + """ + iters = int(iters) + if iters < 1: + raise ValueError("Interval must be greater than 0.") + + warm_up = max(0, int(warm_up)) + self._message_interval = iters + self._message_warm_up = warm_up + + def set_log_to_file(self, filename=None, csv=False): + """ + Enables progress logging to file when a filename is passed in, disables + it if ``filename`` is ``False`` or ``None``. + + The argument ``csv`` can be set to ``True`` to write the file in comma + separated value (CSV) format. By default, the file contents will be + similar to the output on screen. + """ + if filename: + self._log_filename = str(filename) + self._log_csv = True if csv else False + else: + self._log_filename = None + self._log_csv = False + + def set_log_to_screen(self, enabled): + """ + Enables or disables progress logging to screen. + """ + self._log_to_screen = True if enabled else False + + def max_iterations(self): + """ + Returns the maximum iterations if this stopping criterion is set, or + ``None`` if it is not. See :meth:`set_max_iterations()`. + """ + return self._max_iterations + + def n_samples(self): + """ + Returns the target number of samples to obtain in the estimated + posterior. + """ + return self._n_samples + + def parallel(self): + """ + Returns the number of parallel worker processes this routine will be + run on, or ``False`` if parallelisation is disabled. + """ + return self._n_workers if self._parallel else False + + def run(self): + """ + Runs the ABC sampler. + """ + if self._max_iterations is None: + raise ValueError("At least one stopping criterion must be set.") + + # Iteration and evaluation counting + iteration = 0 + evaluations = 0 + accepted_count = 0 + + # Choose method to evaluate + f = self._error_measure + + # Create evaluator + if self._parallel: + n_workers = self._n_workers + evaluator = pints.ParallelEvaluator(f, n_workers=n_workers) + else: + evaluator = pints.SequentialEvaluator(f) + + # Set up progress reporting + next_message = 0 + + # Start logging + logging = self._log_to_screen or self._log_filename + if logging: + if self._log_to_screen: + print('Using ' + str(self._sampler.name())) + if self._parallel: + print('Running in parallel with ' + str(n_workers) + + ' worker processess.') + else: + print('Running in sequential mode.') + + # Set up logger + logger = pints.Logger() + if not self._log_to_screen: + logger.set_stream(None) + if self._log_filename: + logger.set_filename(self._log_filename, csv=self._log_csv) + + # Add fields to log + max_iter_guess = max(self._max_iterations or 0, 10000) + max_eval_guess = max_iter_guess + logger.add_counter('Iter.', max_value=max_iter_guess) + logger.add_counter('Eval.', max_value=max_eval_guess) + logger.add_float('Acceptance rate') + self._sampler._log_init(logger) + logger.add_time('Time m:s') + + # Start sampling + timer = pints.Timer() + running = True + + # Specifying the number of samples we want to get + # from the prior at once. It depends on whether we + # are using parallelisation and how many workers + # are being used. + if self._parallel: + n_requested_samples = self._n_workers + else: + n_requested_samples = 1 + + samples = [] + # Sample until we find an acceptable sample + while running: + accepted_vals = None + while accepted_vals is None: + # Get points from prior + xs = self._sampler.ask(n_requested_samples) + + # Simulate and get error + fxs = evaluator.evaluate(xs) + evaluations += self._n_workers + + # Tell sampler errors and get list of acceptable parameters + accepted_vals = self._sampler.tell(fxs) + + accepted_count += len(accepted_vals) + for val in accepted_vals: + samples.append(val) + + iteration += 1 + + # Log progress + if logging and iteration >= next_message: + # Log state + logger.log(iteration, evaluations, ( + accepted_count / evaluations)) + self._sampler._log_write(logger) + logger.log(timer.time()) + + # Choose next logging point + if iteration < self._message_warm_up: + next_message = iteration + 1 + else: + next_message = self._message_interval * ( + 1 + iteration // self._message_interval) + + if iteration >= self._max_iterations: + running = False + halt_message = ('Halting: Maximum number of iterations (' + + str(iteration) + ') reached. Only (' + + str(accepted_count) + ') sample were ' + + 'obtained') + elif accepted_count >= self._n_samples: + running = False + halt_message = ('Halting: target number of samples (' + + str(accepted_count) + ') reached.') + + # Log final state and show halt message + if logging: + logger.log(iteration, evaluations) + self._sampler._log_write(logger) + logger.log(timer.time()) + if self._log_to_screen: + print(halt_message) + samples = np.array(samples) + return samples + + def log_filename(self): + """ + Returns the file name in which all the logs related to the + ABC routine will be stored. + """ + return self._log_filename + + def sampler(self): + """ + Returns the underlying sampler object. + """ + return self._sampler + + def set_max_iterations(self, iterations=10000): + """ + Adds a stopping criterion, allowing the routine to halt after the + given number of ``iterations``. + + This criterion is enabled by default. To disable it, use + ``set_max_iterations(None)``. + """ + if iterations is not None: + iterations = int(iterations) + if iterations < 0: + raise ValueError( + 'Maximum number of iterations cannot be negative.') + self._max_iterations = iterations + + def set_n_samples(self, n_samples=500): + """ + Sets a target number of samples + """ + self._n_samples = n_samples + + def set_parallel(self, parallel=False): + """ + Enables/disables parallel evaluation. + + If ``parallel=True``, the method will run using a number of worker + processes equal to the detected cpu core count. The number of workers + can be set explicitly by setting ``parallel`` to an integer greater + than 0. + Parallelisation can be disabled by setting ``parallel`` to ``0`` or + ``False``. + """ + if parallel is True: + self._n_workers = pints.ParallelEvaluator.cpu_count() + self._parallel = True + + elif parallel >= 1: + self._parallel = True + self._n_workers = int(parallel) + else: + self._parallel = False + self._n_workers = 1 diff --git a/pints/_abc/_abc_rejection.py b/pints/_abc/_abc_rejection.py new file mode 100644 index 000000000..ed97a0876 --- /dev/null +++ b/pints/_abc/_abc_rejection.py @@ -0,0 +1,91 @@ +# +# ABC Rejection method +# +# This file is part of PINTS (https://github.com/pints-team/pints/) which is +# released under the BSD 3-clause license. See accompanying LICENSE.md for +# copyright notice and full license details. +# +import pints +import numpy as np + + +class RejectionABC(pints.ABCSampler): + r""" + Implements the rejection ABC algorithm as described in [1]. + + Here is a high-level description of the algorithm: + + .. math:: + \begin{align} + \theta^* &\sim p(\theta) \\ + x &\sim p(x|\theta^*) \\ + \textrm{if } s(x) < \textrm{threshold}, \textrm{then} \\ + \theta^* \textrm{ is added to list of samples} \\ + \end{align} + + In other words, the first two steps sample parameters + from the prior distribution :math:`p(\theta)` and then sample + simulated data from the sampling distribution (conditional on + the sampled parameter values), :math:`p(x|\theta^*)`. + In the end, if the error measure between our simulated data and + the original data is within the threshold, we add the sampled + parameters to the list of samples. + + References + ---------- + .. [1] "Approximate Bayesian Computation (ABC) in practice". Katalin + Csillery, Michael G.B. Blum, Oscar E. Gaggiotti, Olivier Francois + (2010) Trends in Ecology & Evolution + https://doi.org/10.1016/j.tree.2010.04.001 + + """ + def __init__(self, log_prior): + + self._log_prior = log_prior + self._threshold = 1 + self._xs = None + self._ready_for_tell = False + + def name(self): + """ See :meth:`pints.ABCSampler.name()`. """ + return 'Rejection ABC' + + def ask(self, n_samples): + """ See :meth:`ABCSampler.ask()`. """ + if self._ready_for_tell: + raise RuntimeError('Ask called before tell.') + self._xs = self._log_prior.sample(n_samples) + + self._ready_for_tell = True + return self._xs + + def tell(self, fx): + """ See :meth:`ABCSampler.tell()`. """ + if not self._ready_for_tell: + raise RuntimeError('Tell called before ask.') + self._ready_for_tell = False + + fx = pints.vector(fx) + accepted = self._xs[fx < self._threshold] + if np.sum(accepted) == 0: + return None + else: + return [self._xs.tolist() for c, x in + enumerate(accepted) if x.all()] + + def threshold(self): + """ + Returns threshold error distance that determines if a sample is + accepted (if ``error < threshold``). + """ + return self._threshold + + def set_threshold(self, threshold): + """ + Sets threshold error distance that determines if a sample is accepted + (if ``error < threshold``). + """ + x = float(threshold) + if x <= 0: + raise ValueError('Threshold must be greater than zero.') + self._threshold = threshold diff --git a/pints/tests/test_abc_controller.py b/pints/tests/test_abc_controller.py new file mode 100644 index 000000000..487895c52 --- /dev/null +++ b/pints/tests/test_abc_controller.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# +# Tests the ABC Controller. +# +# This file is part of PINTS (https://github.com/pints-team/pints/) which is +# released under the BSD 3-clause license. See accompanying LICENSE.md for +# copyright notice and full license details. +# +import pints +import pints.toy +import pints.toy.stochastic +import unittest +import numpy as np +from shared import StreamCapture + + +class TestABCController(unittest.TestCase): + """ + Tests the ABCController class. + """ + + @classmethod + def setUpClass(cls): + """ Prepare problem for tests. """ + + # Create toy model + cls.model = pints.toy.stochastic.DegradationModel() + cls.real_parameters = [0.1] + cls.times = np.linspace(0, 10, 10) + cls.values = cls.model.simulate(cls.real_parameters, cls.times) + + # Create an object (problem) with links to the model and time series + cls.problem = pints.SingleOutputProblem( + cls.model, cls.times, cls.values) + + # Create a uniform prior over both the parameters + cls.log_prior = pints.UniformLogPrior( + [0.0], + [0.3] + ) + + # Set error measure + cls.error_measure = pints.RootMeanSquaredError(cls.problem) + + def test_nparameters_error(self): + # Test that error is thrown when parameters from log prior and error + # measure do not match. + log_prior = pints.UniformLogPrior( + [0.0, 0, 0], + [0.2, 100, 1]) + + self.assertRaises(ValueError, pints.ABCController, self.error_measure, + log_prior) + + def test_error_measure_instance(self): + # Test that error is thrown when we use an error measure which is not + # an instance of ``pints.ErrorMeasure``. + # Set a log prior as the error measure to trigger the warning + wrong_error_measure = pints.UniformLogPrior( + [0.0, 0, 0], + [0.2, 100, 1]) + + self.assertRaises( + ValueError, + pints.ABCController, + wrong_error_measure, + self.log_prior) + + def test_stopping(self): + #" Test different stopping criteria. + + abc = pints.ABCController(self.error_measure, self.log_prior) + + # Test setting max iterations + maxi = abc.max_iterations() + 2 + self.assertNotEqual(maxi, abc.max_iterations()) + abc.set_max_iterations(maxi) + self.assertEqual(maxi, abc.max_iterations()) + self.assertRaisesRegex( + ValueError, + 'Maximum number of iterations cannot be negative.', + abc.set_max_iterations, -1) + + # # Test without stopping criteria + abc.set_max_iterations(None) + self.assertIsNone(abc.max_iterations()) + self.assertRaisesRegex( + ValueError, + 'At least one stopping criterion must be set.', + abc.run) + + def test_parallel(self): + # Test running ABC with parallisation. + + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + + # Test with auto-detected number of worker processes + self.assertFalse(abc.parallel()) + abc.set_parallel(True) + self.assertTrue(abc.parallel()) + self.assertEqual(abc.parallel(), pints.ParallelEvaluator.cpu_count()) + + # Test with fixed number of worker processes + abc.set_parallel(2) + self.assertEqual(abc.parallel(), 2) + + def test_logging(self): + # Tests logging to screen + + # No output + with StreamCapture() as capture: + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_max_iterations(10) + abc.set_log_to_screen(False) + abc.set_log_to_file(False) + abc.run() + self.assertEqual(capture.text(), '') + + # With output to screen + np.random.seed(1) + with StreamCapture() as capture: + pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_max_iterations(10) + abc.set_log_to_screen(True) + abc.set_log_to_file(False) + abc.run() + lines = capture.text().splitlines() + self.assertTrue(len(lines) > 0) + + # With output to screen + np.random.seed(1) + with StreamCapture() as capture: + pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_max_iterations(10) + abc.set_log_to_screen(False) + abc.set_log_to_file(True) + abc.run() + lines = capture.text().splitlines() + self.assertTrue(len(lines) == 0) + + # Invalid log interval + self.assertRaises(ValueError, abc.set_log_interval, 0) + + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_log_to_file("temp_file") + self.assertEqual(abc.log_filename(), "temp_file") + + # tests logging to screen with parallel + with StreamCapture() as capture: + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_parallel(2) + abc.set_max_iterations(10) + abc.set_log_to_screen(False) + abc.set_log_to_file(False) + abc.run() + self.assertEqual(capture.text(), '') + + def test_controller_extra(self): + # Tests various controller aspects + + self.assertRaises(ValueError, pints.ABCController, self.error_measure, + self.error_measure) + self.assertRaisesRegex( + ValueError, 'Given method must extend ABCSampler.', + pints.ABCController, self.error_measure, + self.log_prior, pints.MCMCSampler) + self.assertRaises(ValueError, pints.ABCController, self.error_measure, + pints.MCMCSampler) + self.assertRaises(ValueError, pints.ABCController, self.error_measure, + self.log_prior, 0.0) + + # test setters + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_n_samples(230) + self.assertEqual(abc.n_samples(), 230) + + sampler = abc.sampler() + pt = sampler.ask(1) + self.assertEqual(len(pt), 1) + + abc.set_parallel(False) + self.assertEqual(abc.parallel(), 0) + + with StreamCapture() as capture: + abc = pints.ABCController( + self.error_measure, self.log_prior, method=pints.RejectionABC) + abc.set_parallel(4) + abc.sampler().set_threshold(100) + abc.set_n_samples(1) + abc.run() + lines = capture.text().splitlines() + self.assertTrue(len(lines) > 0) + self.assertTrue(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/pints/tests/test_abc_rejection.py b/pints/tests/test_abc_rejection.py new file mode 100644 index 000000000..eab78786f --- /dev/null +++ b/pints/tests/test_abc_rejection.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# +# Tests the basic methods of the ABC Rejection routine. +# +# This file is part of PINTS (https://github.com/pints-team/pints/) which is +# released under the BSD 3-clause license. See accompanying LICENSE.md for +# copyright notice and full license details. +# +import pints +import pints.toy as toy +import pints.toy.stochastic +import unittest +import numpy as np + + +class TestRejectionABC(unittest.TestCase): + """ + Tests the basic methods of the ABC Rejection routine. + """ + # Set up toy model, parameter values, problem, error measure + @classmethod + def setUpClass(cls): + """ Set up problem for tests. """ + + # Create toy model + cls.model = toy.stochastic.DegradationModel() + cls.real_parameters = [0.1] + cls.times = np.linspace(0, 10, 10) + cls.values = cls.model.simulate(cls.real_parameters, cls.times) + + # Create an object (problem) with links to the model and time series + cls.problem = pints.SingleOutputProblem( + cls.model, cls.times, cls.values) + + # Create a uniform prior over both the parameters + cls.log_prior = pints.UniformLogPrior( + [0.0], + [0.3] + ) + + # Set error measure + cls.error_measure = pints.RootMeanSquaredError(cls.problem) + + def test_method(self): + + # Create abc rejection scheme + abc = pints.RejectionABC(self.log_prior) + + # Configure + n_draws = 1 + niter = 20 + + # Perform short run using ask and tell framework + samples = [] + while len(samples) < niter: + x = abc.ask(n_draws)[0] + fx = self.error_measure(x) + sample = abc.tell(fx) + while sample is None: + x = abc.ask(n_draws)[0] + fx = self.error_measure(x) + sample = abc.tell(fx) + samples.append(sample) + + samples = np.array(samples) + self.assertEqual(samples.shape[0], niter) + + def test_errors(self): + # test errors in abc rejection + abc = pints.RejectionABC(self.log_prior) + abc.ask(1) + # test two asks raises error + self.assertRaises(RuntimeError, abc.ask, 1) + # test tell with large values returns empty arrays + self.assertTrue(abc.tell(np.array([100])) is None) + # test error raised if tell called before ask + self.assertRaises(RuntimeError, abc.tell, 2.5) + + def test_setters_and_getters(self): + # test setting and getting + abc = pints.RejectionABC(self.log_prior) + self.assertEqual('Rejection ABC', abc.name()) + self.assertEqual(abc.threshold(), 1) + abc.set_threshold(2) + self.assertEqual(abc.threshold(), 2) + self.assertRaises(ValueError, abc.set_threshold, -3) + + +if __name__ == '__main__': + unittest.main()