|
1 | 1 | { |
2 | 2 | "cells": [ |
3 | 3 | { |
4 | | - "metadata": {}, |
5 | 4 | "cell_type": "markdown", |
| 5 | + "id": "372289411a2aa7b3", |
| 6 | + "metadata": {}, |
6 | 7 | "source": [ |
7 | 8 | "# Prior distributions in PEtab\n", |
8 | 9 | "\n", |
|
18 | 19 | "* *Initialization priors* can be used as a hint for the optimization algorithm. They will not enter the objective function. They are specified in the `initializationPriorType` and `initializationPriorParameters` columns of the parameter table.\n", |
19 | 20 | "\n", |
20 | 21 | "\n" |
21 | | - ], |
22 | | - "id": "372289411a2aa7b3" |
| 22 | + ] |
23 | 23 | }, |
24 | 24 | { |
| 25 | + "cell_type": "code", |
| 26 | + "execution_count": null, |
| 27 | + "id": "initial_id", |
25 | 28 | "metadata": { |
26 | 29 | "collapsed": true |
27 | 30 | }, |
28 | | - "cell_type": "code", |
| 31 | + "outputs": [], |
29 | 32 | "source": [ |
30 | 33 | "import matplotlib.pyplot as plt\n", |
31 | 34 | "import numpy as np\n", |
32 | 35 | "import seaborn as sns\n", |
33 | 36 | "\n", |
34 | 37 | "from petab.v1.C import *\n", |
| 38 | + "from petab.v1.parameters import unscale\n", |
35 | 39 | "from petab.v1.priors import Prior\n", |
36 | | - "from petab.v1.parameters import scale, unscale\n", |
37 | | - "\n", |
38 | 40 | "\n", |
39 | 41 | "sns.set_style(None)\n", |
40 | 42 | "\n", |
|
51 | 53 | " plt.tight_layout()\n", |
52 | 54 | " plt.show()\n", |
53 | 55 | "\n", |
54 | | - "def plot_single(prior: Prior, scaled: bool = False, ax=None, sample: np.array = None):\n", |
| 56 | + "\n", |
| 57 | + "def plot_single(\n", |
| 58 | + " prior: Prior, scaled: bool = False, ax=None, sample: np.array = None\n", |
| 59 | + "):\n", |
55 | 60 | " fig = None\n", |
56 | 61 | " if ax is None:\n", |
57 | 62 | " fig, ax = plt.subplots()\n", |
|
64 | 69 | " sample = unscale(sample, prior.transformation)\n", |
65 | 70 | " bounds = prior.bounds\n", |
66 | 71 | " else:\n", |
67 | | - " bounds = (prior.lb_scaled, prior.ub_scaled) if prior.bounds is not None else None\n", |
| 72 | + " bounds = (\n", |
| 73 | + " (prior.lb_scaled, prior.ub_scaled)\n", |
| 74 | + " if prior.bounds is not None\n", |
| 75 | + " else None\n", |
| 76 | + " )\n", |
68 | 77 | "\n", |
69 | 78 | " # plot pdf\n", |
70 | | - " xmin = min(sample.min(), bounds[0] if prior.bounds is not None else sample.min())\n", |
71 | | - " xmax = max(sample.max(), bounds[1] if prior.bounds is not None else sample.max())\n", |
| 79 | + " xmin = min(\n", |
| 80 | + " sample.min(), bounds[0] if prior.bounds is not None else sample.min()\n", |
| 81 | + " )\n", |
| 82 | + " xmax = max(\n", |
| 83 | + " sample.max(), bounds[1] if prior.bounds is not None else sample.max()\n", |
| 84 | + " )\n", |
72 | 85 | " padding = 0.1 * (xmax - xmin)\n", |
73 | 86 | " xmin -= padding\n", |
74 | 87 | " xmax += padding\n", |
75 | 88 | " x = np.linspace(xmin, xmax, 500)\n", |
76 | 89 | " y = prior.pdf(x, x_scaled=scaled, rescale=scaled)\n", |
77 | | - " ax.plot(x, y, color='red', label='pdf')\n", |
| 90 | + " ax.plot(x, y, color=\"red\", label=\"pdf\")\n", |
78 | 91 | "\n", |
79 | | - " sns.histplot(sample, stat='density', ax=ax, label=\"sample\")\n", |
| 92 | + " sns.histplot(sample, stat=\"density\", ax=ax, label=\"sample\")\n", |
80 | 93 | "\n", |
81 | 94 | " # plot bounds\n", |
82 | 95 | " if prior.bounds is not None:\n", |
83 | 96 | " for bound in bounds:\n", |
84 | 97 | " if bound is not None and np.isfinite(bound):\n", |
85 | | - " ax.axvline(bound, color='black', linestyle='--', label='bound')\n", |
| 98 | + " ax.axvline(bound, color=\"black\", linestyle=\"--\", label=\"bound\")\n", |
86 | 99 | "\n", |
87 | 100 | " if fig is not None:\n", |
88 | 101 | " ax.set_title(str(prior))\n", |
89 | 102 | "\n", |
90 | 103 | " if scaled:\n", |
91 | | - " ax.set_xlabel(f'Parameter value on parameter scale ({prior.transformation})')\n", |
| 104 | + " ax.set_xlabel(\n", |
| 105 | + " f\"Parameter value on parameter scale ({prior.transformation})\"\n", |
| 106 | + " )\n", |
92 | 107 | " ax.set_ylabel(\"Rescaled density\")\n", |
93 | 108 | " else:\n", |
94 | | - " ax.set_xlabel('Parameter value')\n", |
| 109 | + " ax.set_xlabel(\"Parameter value\")\n", |
95 | 110 | "\n", |
96 | 111 | " ax.grid(False)\n", |
97 | 112 | " handles, labels = ax.get_legend_handles_labels()\n", |
98 | | - " unique_labels = dict(zip(labels, handles))\n", |
| 113 | + " unique_labels = dict(zip(labels, handles, strict=False))\n", |
99 | 114 | " ax.legend(unique_labels.values(), unique_labels.keys())\n", |
100 | 115 | "\n", |
101 | 116 | " if ax is None:\n", |
102 | | - " plt.show()\n" |
103 | | - ], |
104 | | - "id": "initial_id", |
105 | | - "outputs": [], |
106 | | - "execution_count": null |
| 117 | + " plt.show()" |
| 118 | + ] |
107 | 119 | }, |
108 | 120 | { |
109 | | - "metadata": {}, |
110 | 121 | "cell_type": "markdown", |
111 | | - "source": "The basic distributions are the uniform, normal, Laplace, log-normal, and log-laplace distributions:\n", |
112 | | - "id": "db36a4a93622ccb8" |
| 122 | + "id": "db36a4a93622ccb8", |
| 123 | + "metadata": {}, |
| 124 | + "source": "The basic distributions are the uniform, normal, Laplace, log-normal, and log-laplace distributions:\n" |
113 | 125 | }, |
114 | 126 | { |
115 | | - "metadata": {}, |
116 | 127 | "cell_type": "code", |
| 128 | + "execution_count": null, |
| 129 | + "id": "4f09e50a3db06d9f", |
| 130 | + "metadata": {}, |
| 131 | + "outputs": [], |
117 | 132 | "source": [ |
118 | 133 | "plot_single(Prior(UNIFORM, (0, 1)))\n", |
119 | 134 | "plot_single(Prior(NORMAL, (0, 1)))\n", |
120 | 135 | "plot_single(Prior(LAPLACE, (0, 1)))\n", |
121 | 136 | "plot_single(Prior(LOG_NORMAL, (0, 1)))\n", |
122 | 137 | "plot_single(Prior(LOG_LAPLACE, (1, 0.5)))" |
123 | | - ], |
124 | | - "id": "4f09e50a3db06d9f", |
125 | | - "outputs": [], |
126 | | - "execution_count": null |
| 138 | + ] |
127 | 139 | }, |
128 | 140 | { |
129 | | - "metadata": {}, |
130 | 141 | "cell_type": "markdown", |
131 | | - "source": "If a parameter scale is specified (`parameterScale=lin|log|log10`) and the chosen distribution is not a `parameterScale*`-type distribution, then the distribution parameters are taken as is, i.e., the `parameterScale` is not applied to the distribution parameters. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n", |
132 | | - "id": "dab4b2d1e0f312d8" |
| 142 | + "id": "dab4b2d1e0f312d8", |
| 143 | + "metadata": {}, |
| 144 | + "source": "If a parameter scale is specified (`parameterScale=lin|log|log10`) and the chosen distribution is not a `parameterScale*`-type distribution, then the distribution parameters are taken as is, i.e., the `parameterScale` is not applied to the distribution parameters. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n" |
133 | 145 | }, |
134 | 146 | { |
135 | | - "metadata": {}, |
136 | 147 | "cell_type": "code", |
| 148 | + "execution_count": null, |
| 149 | + "id": "f6192c226f179ef9", |
| 150 | + "metadata": {}, |
| 151 | + "outputs": [], |
137 | 152 | "source": [ |
138 | 153 | "plot(Prior(NORMAL, (10, 2), transformation=LIN))\n", |
139 | 154 | "plot(Prior(NORMAL, (10, 2), transformation=LOG))\n", |
140 | 155 | "\n", |
141 | | - "# Note that the log-normal distribution is different from a log-transformed normal distribution:\n", |
| 156 | + "# Note that the log-normal distribution is different\n", |
| 157 | + "# from a log-transformed normal distribution:\n", |
142 | 158 | "plot(Prior(LOG_NORMAL, (10, 2), transformation=LIN))" |
143 | | - ], |
144 | | - "id": "f6192c226f179ef9", |
145 | | - "outputs": [], |
146 | | - "execution_count": null |
| 159 | + ] |
147 | 160 | }, |
148 | 161 | { |
149 | | - "metadata": {}, |
150 | 162 | "cell_type": "markdown", |
151 | | - "source": "On the log-transformed parameter scale, `Log*` and `parameterScale*` distributions are equivalent:", |
152 | | - "id": "4281ed48859e6431" |
| 163 | + "id": "4281ed48859e6431", |
| 164 | + "metadata": {}, |
| 165 | + "source": "On the log-transformed parameter scale, `Log*` and `parameterScale*` distributions are equivalent:" |
153 | 166 | }, |
154 | 167 | { |
155 | | - "metadata": {}, |
156 | 168 | "cell_type": "code", |
| 169 | + "execution_count": null, |
| 170 | + "id": "34c95268e8921070", |
| 171 | + "metadata": {}, |
| 172 | + "outputs": [], |
157 | 173 | "source": [ |
158 | 174 | "plot(Prior(LOG_NORMAL, (10, 2), transformation=LOG))\n", |
159 | 175 | "plot(Prior(PARAMETER_SCALE_NORMAL, (10, 2)))" |
160 | | - ], |
161 | | - "id": "34c95268e8921070", |
162 | | - "outputs": [], |
163 | | - "execution_count": null |
| 176 | + ] |
164 | 177 | }, |
165 | 178 | { |
166 | | - "metadata": {}, |
167 | 179 | "cell_type": "markdown", |
168 | | - "source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameters are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`.", |
169 | | - "id": "263c9fd31156a4d5" |
| 180 | + "id": "263c9fd31156a4d5", |
| 181 | + "metadata": {}, |
| 182 | + "source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameters are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`." |
170 | 183 | }, |
171 | 184 | { |
172 | | - "metadata": {}, |
173 | 185 | "cell_type": "code", |
| 186 | + "execution_count": null, |
| 187 | + "id": "5ca940bc24312fc6", |
| 188 | + "metadata": {}, |
| 189 | + "outputs": [], |
174 | 190 | "source": [ |
175 | 191 | "# different, because transformation!=LIN\n", |
176 | 192 | "plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n", |
|
179 | 195 | "# same, because transformation=LIN\n", |
180 | 196 | "plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n", |
181 | 197 | "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))" |
182 | | - ], |
183 | | - "id": "5ca940bc24312fc6", |
184 | | - "outputs": [], |
185 | | - "execution_count": null |
| 198 | + ] |
186 | 199 | }, |
187 | 200 | { |
188 | | - "metadata": {}, |
189 | 201 | "cell_type": "markdown", |
190 | | - "source": "The given distributions are truncated at the bounds defined in the parameter table:", |
191 | | - "id": "b1a8b17d765db826" |
| 202 | + "id": "b1a8b17d765db826", |
| 203 | + "metadata": {}, |
| 204 | + "source": "The given distributions are truncated at the bounds defined in the parameter table:" |
192 | 205 | }, |
193 | 206 | { |
194 | | - "metadata": {}, |
195 | 207 | "cell_type": "code", |
| 208 | + "execution_count": null, |
| 209 | + "id": "4ac42b1eed759bdd", |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [], |
196 | 212 | "source": [ |
197 | 213 | "plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n", |
198 | 214 | "plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n", |
199 | 215 | "plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n", |
200 | 216 | "plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n", |
201 | | - "plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))" |
202 | | - ], |
203 | | - "id": "4ac42b1eed759bdd", |
204 | | - "outputs": [], |
205 | | - "execution_count": null |
| 217 | + "plot(\n", |
| 218 | + " Prior(\n", |
| 219 | + " PARAMETER_SCALE_UNIFORM,\n", |
| 220 | + " (-3, 1),\n", |
| 221 | + " bounds=(1e-2, 1),\n", |
| 222 | + " transformation=LOG10,\n", |
| 223 | + " )\n", |
| 224 | + ")" |
| 225 | + ] |
206 | 226 | }, |
207 | 227 | { |
208 | | - "metadata": {}, |
209 | 228 | "cell_type": "markdown", |
210 | | - "source": "Further distribution examples:", |
211 | | - "id": "45ffce1341483f24" |
| 229 | + "id": "45ffce1341483f24", |
| 230 | + "metadata": {}, |
| 231 | + "source": "Further distribution examples:" |
212 | 232 | }, |
213 | 233 | { |
214 | | - "metadata": {}, |
215 | 234 | "cell_type": "code", |
| 235 | + "execution_count": null, |
| 236 | + "id": "581e1ac431860419", |
| 237 | + "metadata": {}, |
| 238 | + "outputs": [], |
216 | 239 | "source": [ |
217 | 240 | "plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n", |
218 | | - "plot(Prior(PARAMETER_SCALE_NORMAL, (2, 1), bounds=(10**0, 10**3), transformation=\"log10\"))\n", |
| 241 | + "plot(\n", |
| 242 | + " Prior(\n", |
| 243 | + " PARAMETER_SCALE_NORMAL,\n", |
| 244 | + " (2, 1),\n", |
| 245 | + " bounds=(10**0, 10**3),\n", |
| 246 | + " transformation=\"log10\",\n", |
| 247 | + " )\n", |
| 248 | + ")\n", |
219 | 249 | "plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n", |
220 | 250 | "plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n", |
221 | 251 | "plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))" |
222 | | - ], |
223 | | - "id": "581e1ac431860419", |
224 | | - "outputs": [], |
225 | | - "execution_count": null |
| 252 | + ] |
226 | 253 | } |
227 | 254 | ], |
228 | 255 | "metadata": { |
|
0 commit comments