|
32 | 32 | "import seaborn as sns\n", |
33 | 33 | "\n", |
34 | 34 | "from petab.v1.C import *\n", |
35 | | - "from petab.v1.distributions import *\n", |
| 35 | + "from petab.v1.priors import Prior\n", |
36 | 36 | "\n", |
37 | 37 | "sns.set_style(None)\n", |
38 | 38 | "\n", |
39 | 39 | "\n", |
40 | | - "def plot(distr: Distribution, ax=None):\n", |
| 40 | + "def plot(prior: Prior, ax=None):\n", |
41 | 41 | " \"\"\"Visualize a distribution.\"\"\"\n", |
42 | 42 | " if ax is None:\n", |
43 | 43 | " fig, ax = plt.subplots()\n", |
44 | 44 | "\n", |
45 | | - " sample = distr.sample(10000)\n", |
| 45 | + " sample = prior.sample(10000)\n", |
46 | 46 | "\n", |
47 | 47 | " # pdf\n", |
48 | | - " xmin = min(sample.min(), distr.lb_scaled if distr.bounds is not None else sample.min())\n", |
49 | | - " xmax = max(sample.max(), distr.ub_scaled if distr.bounds is not None else sample.max())\n", |
| 48 | + " xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n", |
| 49 | + " xmax = max(sample.max(), prior.ub_scaled if prior.bounds is not None else sample.max())\n", |
50 | 50 | " x = np.linspace(xmin, xmax, 500)\n", |
51 | | - " y = distr.pdf(x)\n", |
| 51 | + " y = prior.pdf(x)\n", |
52 | 52 | " ax.plot(x, y, color='red', label='pdf')\n", |
53 | 53 | "\n", |
54 | 54 | " sns.histplot(sample, stat='density', ax=ax, label=\"sample\")\n", |
55 | 55 | "\n", |
56 | 56 | " # bounds\n", |
57 | | - " if distr.bounds is not None:\n", |
58 | | - " for bound in (distr.lb_scaled, distr.ub_scaled):\n", |
| 57 | + " if prior.bounds is not None:\n", |
| 58 | + " for bound in (prior.lb_scaled, prior.ub_scaled):\n", |
59 | 59 | " if bound is not None and np.isfinite(bound):\n", |
60 | 60 | " ax.axvline(bound, color='black', linestyle='--', label='bound')\n", |
61 | 61 | "\n", |
62 | | - " ax.set_title(str(distr))\n", |
| 62 | + " ax.set_title(str(prior))\n", |
63 | 63 | " ax.set_xlabel('Parameter value on the parameter scale')\n", |
64 | 64 | " ax.grid(False)\n", |
65 | 65 | " handles, labels = ax.get_legend_handles_labels()\n", |
|
81 | 81 | "metadata": {}, |
82 | 82 | "cell_type": "code", |
83 | 83 | "source": [ |
84 | | - "plot(Uniform(0, 1))\n", |
85 | | - "plot(Normal(0, 1))\n", |
86 | | - "plot(Laplace(0, 1))\n", |
87 | | - "plot(LogNormal(0, 1))\n", |
88 | | - "plot(LogLaplace(1, 0.5))" |
| 84 | + "plot(Prior(UNIFORM, (0, 1)))\n", |
| 85 | + "plot(Prior(NORMAL, (0, 1)))\n", |
| 86 | + "plot(Prior(LAPLACE, (0, 1)))\n", |
| 87 | + "plot(Prior(LOG_NORMAL, (0, 1)))\n", |
| 88 | + "plot(Prior(LOG_LAPLACE, (1, 0.5)))" |
89 | 89 | ], |
90 | 90 | "id": "4f09e50a3db06d9f", |
91 | 91 | "outputs": [], |
|
101 | 101 | "metadata": {}, |
102 | 102 | "cell_type": "code", |
103 | 103 | "source": [ |
104 | | - "plot(Normal(10, 2, transformation=LIN))\n", |
105 | | - "plot(Normal(10, 2, transformation=LOG))\n", |
| 104 | + "plot(Prior(NORMAL, (10, 2), transformation=LIN))\n", |
| 105 | + "plot(Prior(NORMAL, (10, 2), transformation=LOG))\n", |
| 106 | + "\n", |
106 | 107 | "# Note that the log-normal distribution is different from a log-transformed normal distribution:\n", |
107 | | - "plot(LogNormal(10, 2, transformation=LIN))" |
| 108 | + "plot(Prior(LOG_NORMAL, (10, 2), transformation=LIN))" |
108 | 109 | ], |
109 | 110 | "id": "f6192c226f179ef9", |
110 | 111 | "outputs": [], |
|
120 | 121 | "metadata": {}, |
121 | 122 | "cell_type": "code", |
122 | 123 | "source": [ |
123 | | - "plot(LogNormal(10, 2, transformation=LOG))\n", |
124 | | - "plot(ParameterScaleNormal(10, 2))" |
| 124 | + "plot(Prior(LOG_NORMAL, (10, 2), transformation=LOG))\n", |
| 125 | + "plot(Prior(PARAMETER_SCALE_NORMAL, (10, 2)))" |
125 | 126 | ], |
126 | 127 | "id": "34c95268e8921070", |
127 | 128 | "outputs": [], |
|
137 | 138 | "metadata": {}, |
138 | 139 | "cell_type": "code", |
139 | 140 | "source": [ |
140 | | - "plot(Uniform(0, 1, transformation=LOG10))\n", |
141 | | - "plot(ParameterScaleUniform(0, 1, transformation=LOG10))\n", |
| 141 | + "plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n", |
| 142 | + "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n", |
142 | 143 | "\n", |
143 | | - "plot(Uniform(0, 1, transformation=LIN))\n", |
144 | | - "plot(ParameterScaleUniform(0, 1, transformation=LIN))\n" |
| 144 | + "plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n", |
| 145 | + "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n" |
145 | 146 | ], |
146 | 147 | "id": "5ca940bc24312fc6", |
147 | 148 | "outputs": [], |
|
157 | 158 | "metadata": {}, |
158 | 159 | "cell_type": "code", |
159 | 160 | "source": [ |
160 | | - "plot(Normal(0, 1, bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n", |
161 | | - "plot(Uniform(0, 1, bounds=(0.1, 0.9))) # significant clipping-bias" |
| 161 | + "plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n", |
| 162 | + "plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias" |
162 | 163 | ], |
163 | 164 | "id": "4ac42b1eed759bdd", |
164 | 165 | "outputs": [], |
|
174 | 175 | "metadata": {}, |
175 | 176 | "cell_type": "code", |
176 | 177 | "source": [ |
177 | | - "plot(Normal(10, 1, bounds=(6, 14), transformation=\"log10\"))\n", |
178 | | - "plot(ParameterScaleNormal(10, 1, bounds=(10**6, 10**14), transformation=\"log10\"))\n" |
| 178 | + "plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n", |
| 179 | + "plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n", |
| 180 | + "plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n", |
| 181 | + "\n" |
179 | 182 | ], |
180 | 183 | "id": "581e1ac431860419", |
181 | 184 | "outputs": [], |
|
185 | 188 | "metadata": {}, |
186 | 189 | "cell_type": "code", |
187 | 190 | "source": "", |
188 | | - "id": "802a64be56a6c94f", |
| 191 | + "id": "633733651bbc3ef0", |
189 | 192 | "outputs": [], |
190 | 193 | "execution_count": null |
191 | 194 | } |
|
0 commit comments