Skip to content

Commit 2484a7f

Browse files
committed
prior always on linear
1 parent 155853f commit 2484a7f

File tree

6 files changed

+219
-77
lines changed

6 files changed

+219
-77
lines changed

doc/example/distributions.ipynb

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,42 +33,73 @@
3333
"\n",
3434
"from petab.v1.C import *\n",
3535
"from petab.v1.priors import Prior\n",
36+
"from petab.v1.parameters import scale, unscale\n",
37+
"\n",
3638
"\n",
3739
"sns.set_style(None)\n",
3840
"\n",
3941
"\n",
40-
"def plot(prior: Prior, ax=None):\n",
42+
"def plot(prior: Prior):\n",
4143
" \"\"\"Visualize a distribution.\"\"\"\n",
44+
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
45+
" sample = prior.sample(20_000, x_scaled=True)\n",
46+
"\n",
47+
" fig.suptitle(str(prior))\n",
48+
"\n",
49+
" plot_single(prior, ax=ax1, sample=sample, scaled=False)\n",
50+
" plot_single(prior, ax=ax2, sample=sample, scaled=True)\n",
51+
" plt.tight_layout()\n",
52+
" plt.show()\n",
53+
"\n",
54+
"def plot_single(prior: Prior, scaled: bool = False, ax=None, sample: np.array = None):\n",
55+
" fig = None\n",
4256
" if ax is None:\n",
4357
" fig, ax = plt.subplots()\n",
4458
"\n",
45-
" sample = prior.sample(20_000)\n",
59+
" if sample is None:\n",
60+
" sample = prior.sample(20_000)\n",
61+
"\n",
62+
" # assuming scaled sample\n",
63+
" if not scaled:\n",
64+
" sample = unscale(sample, prior.transformation)\n",
65+
" bounds = prior.bounds\n",
66+
" else:\n",
67+
" bounds = (prior.lb_scaled, prior.ub_scaled) if prior.bounds is not None else None\n",
4668
"\n",
47-
" # pdf\n",
48-
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n",
49-
" xmax = max(sample.max(), prior.ub_scaled if prior.bounds is not None else sample.max())\n",
69+
" # plot pdf\n",
70+
" xmin = min(sample.min(), bounds[0] if prior.bounds is not None else sample.min())\n",
71+
" xmax = max(sample.max(), bounds[1] if prior.bounds is not None else sample.max())\n",
5072
" padding = 0.1 * (xmax - xmin)\n",
5173
" xmin -= padding\n",
5274
" xmax += padding\n",
5375
" x = np.linspace(xmin, xmax, 500)\n",
54-
" y = prior.pdf(x)\n",
76+
" y = prior.pdf(x, x_scaled=scaled, rescale=scaled)\n",
5577
" ax.plot(x, y, color='red', label='pdf')\n",
5678
"\n",
5779
" sns.histplot(sample, stat='density', ax=ax, label=\"sample\")\n",
5880
"\n",
59-
" # bounds\n",
81+
" # plot bounds\n",
6082
" if prior.bounds is not None:\n",
61-
" for bound in (prior.lb_scaled, prior.ub_scaled):\n",
83+
" for bound in bounds:\n",
6284
" if bound is not None and np.isfinite(bound):\n",
6385
" ax.axvline(bound, color='black', linestyle='--', label='bound')\n",
6486
"\n",
65-
" ax.set_title(str(prior))\n",
66-
" ax.set_xlabel('Parameter value on the parameter scale')\n",
87+
" if fig is not None:\n",
88+
" ax.set_title(str(prior))\n",
89+
"\n",
90+
" if scaled:\n",
91+
" ax.set_xlabel(f'Parameter value on parameter scale ({prior.transformation})')\n",
92+
" ax.set_ylabel(\"Rescaled density\")\n",
93+
" else:\n",
94+
" ax.set_xlabel('Parameter value')\n",
95+
"\n",
6796
" ax.grid(False)\n",
6897
" handles, labels = ax.get_legend_handles_labels()\n",
6998
" unique_labels = dict(zip(labels, handles))\n",
7099
" ax.legend(unique_labels.values(), unique_labels.keys())\n",
71-
" plt.show()"
100+
"\n",
101+
" if ax is None:\n",
102+
" plt.show()\n"
72103
],
73104
"id": "initial_id",
74105
"outputs": [],
@@ -84,11 +115,11 @@
84115
"metadata": {},
85116
"cell_type": "code",
86117
"source": [
87-
"plot(Prior(UNIFORM, (0, 1)))\n",
88-
"plot(Prior(NORMAL, (0, 1)))\n",
89-
"plot(Prior(LAPLACE, (0, 1)))\n",
90-
"plot(Prior(LOG_NORMAL, (0, 1)))\n",
91-
"plot(Prior(LOG_LAPLACE, (1, 0.5)))"
118+
"plot_single(Prior(UNIFORM, (0, 1)))\n",
119+
"plot_single(Prior(NORMAL, (0, 1)))\n",
120+
"plot_single(Prior(LAPLACE, (0, 1)))\n",
121+
"plot_single(Prior(LOG_NORMAL, (0, 1)))\n",
122+
"plot_single(Prior(LOG_LAPLACE, (1, 0.5)))"
92123
],
93124
"id": "4f09e50a3db06d9f",
94125
"outputs": [],
@@ -97,7 +128,7 @@
97128
{
98129
"metadata": {},
99130
"cell_type": "markdown",
100-
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10` not a `parameterScale*`-type distribution), the sample is transformed accordingly (but not the distribution parameters):\n",
131+
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10`) and the chosen distribution is not a `parameterScale*`-type distribution, then the distribution parameters are taken as is, i.e., the `parameterScale` is not applied to the distribution parameters. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n",
101132
"id": "dab4b2d1e0f312d8"
102133
},
103134
{
@@ -134,7 +165,7 @@
134165
{
135166
"metadata": {},
136167
"cell_type": "markdown",
137-
"source": "Prior distributions can also be defined on the parameter scale by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, 1) the distribution parameter are interpreted on the transformed parameter scale, and 2) a sample from the given distribution is used directly, without applying any transformation according to `parameterScale` (this implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`):",
168+
"source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameters are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`.",
138169
"id": "263c9fd31156a4d5"
139170
},
140171
{
@@ -167,7 +198,7 @@
167198
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n",
168199
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n",
169200
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n",
170-
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))\n"
201+
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))"
171202
],
172203
"id": "4ac42b1eed759bdd",
173204
"outputs": [],
@@ -184,7 +215,7 @@
184215
"cell_type": "code",
185216
"source": [
186217
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n",
187-
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**9, 10**14), transformation=\"log10\"))\n",
218+
"plot(Prior(PARAMETER_SCALE_NORMAL, (2, 1), bounds=(10**0, 10**3), transformation=\"log10\"))\n",
188219
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
189220
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n",
190221
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))"

petab/v1/C.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@
207207
PARAMETER_SCALE_LAPLACE,
208208
]
209209

210+
#: parameterScale*-type prior distributions
211+
PARAMETER_SCALE_PRIOR_TYPES = [
212+
PARAMETER_SCALE_UNIFORM,
213+
PARAMETER_SCALE_NORMAL,
214+
PARAMETER_SCALE_LAPLACE,
215+
]
216+
210217
#: Supported noise distributions
211218
NOISE_MODELS = [NORMAL, LAPLACE]
212219

petab/v1/distributions.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def _pdf_transformed_untruncated(self, x) -> np.ndarray | float:
168168

169169
# handle the log transformation; see also:
170170
# https://en.wikipedia.org/wiki/Probability_density_function#Scalar_to_scalar
171-
chain_rule_factor = (
172-
(1 / (x * np.log(self._logbase))) if self._logbase else 1
173-
)
171+
with np.errstate(invalid="ignore", divide="ignore"):
172+
chain_rule_factor = (
173+
(1 / (x * np.log(self._logbase))) if self._logbase else 1
174+
)
174175

175-
with np.errstate(invalid="ignore"):
176176
return np.where(
177177
x > 0,
178178
self._pdf_untransformed_untruncated(self._log(x))
@@ -242,6 +242,19 @@ def _ppf_transformed_untruncated(self, q) -> np.ndarray | float:
242242
"""
243243
return self._exp(self._ppf_untransformed_untruncated(q))
244244

245+
def ppf(self, q) -> np.ndarray | float:
246+
"""Percent point function at q.
247+
248+
:param q: The quantile at which to evaluate the PPF.
249+
:return: The value of the PPF at ``q``.
250+
"""
251+
if self._trunc is None:
252+
return self._ppf_transformed_untruncated(q)
253+
254+
# Adjust quantiles to account for truncation
255+
adjusted_q = self._cd_low + q * (self._cd_high - self._cd_low)
256+
return self._ppf_transformed_untruncated(adjusted_q)
257+
245258
def _inverse_transform_sample(self, shape) -> np.ndarray | float:
246259
"""Generate an inverse transform sample from the transformed and
247260
truncated distribution.

petab/v1/priors.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -161,24 +161,31 @@ def type(self) -> str:
161161

162162
@property
163163
def parameters(self) -> tuple:
164+
"""The parameters of the distribution."""
164165
return self._parameters
165166

166167
@property
167168
def bounds(self) -> tuple[float, float] | None:
169+
"""The non-scaled bounds of the distribution."""
168170
return self._bounds
169171

170172
@property
171173
def transformation(self) -> str:
174+
"""The `parameterScale`."""
172175
return self._transformation
173176

174-
def sample(self, shape=None) -> np.ndarray:
177+
def sample(self, shape=None, x_scaled=False) -> np.ndarray | float:
175178
"""Sample from the distribution.
176179
177180
:param shape: The shape of the sample.
181+
:param x_scaled: Whether the sample should be on the parameter scale.
178182
:return: A sample from the distribution.
179183
"""
180184
raw_sample = self.distribution.sample(shape)
181-
return self._scale_sample(raw_sample)
185+
if x_scaled:
186+
return self._scale_sample(raw_sample)
187+
else:
188+
return raw_sample
182189

183190
def _scale_sample(self, sample):
184191
"""Scale the sample to the parameter space"""
@@ -196,14 +203,8 @@ def ub_scaled(self) -> float:
196203
"""The upper bound on the parameter scale."""
197204
return scale(self.bounds[1], self.transformation)
198205

199-
def pdf(self, x) -> np.ndarray | float:
200-
"""Probability density function at x.
201-
202-
:param x: The value at which to evaluate the PDF.
203-
``x`` is assumed to be on the parameter scale.
204-
:return: The value of the PDF at ``x``. ``x`` is assumed to be on the
205-
parameter scale.
206-
"""
206+
def _chain_rule_coeff(self, x) -> np.ndarray | float:
207+
"""The chain rule coefficient for the transformation at x."""
207208
x = unscale(x, self.transformation)
208209

209210
# scale the PDF to the parameter scale
@@ -216,36 +217,50 @@ def pdf(self, x) -> np.ndarray | float:
216217
else:
217218
raise ValueError(f"Unknown transformation: {self.transformation}")
218219

219-
return self.distribution.pdf(x) * coeff
220+
return coeff
221+
222+
def pdf(
223+
self, x, x_scaled: bool = False, rescale=False
224+
) -> np.ndarray | float:
225+
"""Probability density function at x.
226+
227+
This accounts for truncation, independent of the `bounds_truncate`
228+
parameter.
220229
221-
def neglogprior(self, x) -> np.ndarray | float:
230+
:param x: The value at which to evaluate the PDF.
231+
``x`` is on the parameter scale only if ``x_scaled`` is ``True``.
232+
:param x_scaled: Whether ``x`` is on the parameter scale.
233+
:param rescale: Whether to rescale the PDF to integrate to 1 on the
234+
parameter scale. Only used if ``x_scaled`` is ``True``.
235+
:return: The value of the PDF at ``x``.
236+
"""
237+
if x_scaled:
238+
coeff = self._chain_rule_coeff(x) if rescale else 1
239+
x = unscale(x, self.transformation)
240+
return self.distribution.pdf(x) * coeff
241+
242+
return self.distribution.pdf(x)
243+
244+
def neglogprior(
245+
self, x: np.array | float, x_scaled: bool = False
246+
) -> np.ndarray | float:
222247
"""Negative log-prior at x.
223248
224249
:param x: The value at which to evaluate the negative log-prior.
225-
``x`` is assumed to be on the parameter scale.
250+
:param x_scaled: Whether ``x`` is on the parameter scale.
251+
Note that the prior is always evaluated on the non-scaled
252+
parameters.
226253
:return: The negative log-prior at ``x``.
227254
"""
228-
# FIXME: the prior is always defined on linear scale
229255
if self._bounds_truncate:
230256
# the truncation is handled by the distribution
231-
return -np.log(self.pdf(x))
257+
# the prior is always evaluated on the non-scaled parameters
258+
return -np.log(self.pdf(x, x_scaled=x_scaled, rescale=False))
232259

233260
# we want to evaluate the prior on the untruncated distribution
234-
x = unscale(x, self.transformation)
235-
236-
# scale the PDF to the parameter scale
237-
if self.transformation == C.LIN:
238-
coeff = 1
239-
elif self.transformation == C.LOG10:
240-
coeff = x * np.log(10)
241-
elif self.transformation == C.LOG:
242-
coeff = x
243-
else:
244-
raise ValueError(f"Unknown transformation: {self.transformation}")
245-
246-
return -np.log(
247-
self.distribution._pdf_transformed_untruncated(x) * coeff
248-
)
261+
if x_scaled:
262+
x = unscale(x, self.transformation)
263+
return -np.log(self.distribution._pdf_transformed_untruncated(x))
249264

250265
@staticmethod
251266
def from_par_dict(
@@ -339,6 +354,7 @@ def priors_to_measurements(problem: Problem):
339354
return new_problem
340355

341356
def scaled_observable_formula(parameter_id, parameter_scale):
357+
# The location parameter of the prior
342358
if parameter_scale == LIN:
343359
return parameter_id
344360
if parameter_scale == LOG:
@@ -367,6 +383,12 @@ def scaled_observable_formula(parameter_id, parameter_scale):
367383
# offset
368384
raise NotImplementedError("Uniform priors are not supported.")
369385

386+
if prior_type not in (C.NORMAL, C.LAPLACE):
387+
# we can't (easily) handle parameterScale* priors or log*-priors
388+
raise NotImplementedError(
389+
f"Objective prior type {prior_type} is not implemented."
390+
)
391+
370392
parameter_id = row.name
371393
prior_parameters = tuple(
372394
map(
@@ -391,7 +413,9 @@ def scaled_observable_formula(parameter_id, parameter_scale):
391413
OBSERVABLE_ID: new_obs_id,
392414
OBSERVABLE_FORMULA: scaled_observable_formula(
393415
parameter_id,
394-
parameter_scale if "parameterScale" in prior_type else LIN,
416+
parameter_scale
417+
if prior_type in C.PARAMETER_SCALE_PRIOR_TYPES
418+
else LIN,
395419
),
396420
NOISE_FORMULA: f"noiseParameter1_{new_obs_id}",
397421
}
@@ -400,12 +424,13 @@ def scaled_observable_formula(parameter_id, parameter_scale):
400424
elif OBSERVABLE_TRANSFORMATION in new_problem.observable_df:
401425
# only set default if the column is already present
402426
new_observable[OBSERVABLE_TRANSFORMATION] = LIN
403-
427+
# type of the underlying distribution
404428
if prior_type in (NORMAL, PARAMETER_SCALE_NORMAL, LOG_NORMAL):
405429
new_observable[NOISE_DISTRIBUTION] = NORMAL
406430
elif prior_type in (LAPLACE, PARAMETER_SCALE_LAPLACE, LOG_LAPLACE):
407431
new_observable[NOISE_DISTRIBUTION] = LAPLACE
408432
else:
433+
# we can't (easily) handle uniform priors in PEtab v1
409434
raise NotImplementedError(
410435
f"Objective prior type {prior_type} is not implemented."
411436
)

petab/v1/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def sample_parameter_startpoints(
8080
[
8181
Prior.from_par_dict(
8282
row, type_="initialization", bounds_truncate=True
83-
).sample(n_starts)
83+
).sample(n_starts, x_scaled=True)
8484
for row in par_to_estimate.to_dict("records")
8585
]
8686
).T

0 commit comments

Comments
 (0)