Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2fd6da8
Implement proper truncation for prior distributions
dweindl Dec 11, 2024
65ef80f
optional
dweindl Dec 11, 2024
c01f2fb
fix cdf normalization
dweindl Dec 11, 2024
d3b4e7f
review
dweindl Dec 11, 2024
057457f
Merge branch 'develop' into 330_truncated
dweindl Dec 11, 2024
1425d9c
fix cdf/pdf outside bounds / <0
dweindl Dec 11, 2024
155853f
Always sample correctly, but optionally use unscaled pdf for neglogprior
dweindl Dec 11, 2024
2484a7f
prior always on linear
dweindl Dec 12, 2024
a17aa62
Fix Prior.from_par_dict for missing priorParameters columns (#341)
dweindl Dec 12, 2024
6f005b8
Merge branch 'develop' into 330_truncated
dweindl Dec 18, 2024
a5d2a3d
Update doc/example/distributions.ipynb
dweindl Feb 11, 2025
2529bf9
Merge branch 'develop' into 330_truncated
dweindl Feb 11, 2025
7ae0f40
tuncation/transformation
dweindl Mar 10, 2025
b762237
Merge branch 'develop' into 330_truncated
dweindl Mar 11, 2025
51367db
reruff
dweindl Mar 28, 2025
9e65449
Merge branch 'develop' into 330_truncated
dweindl Mar 28, 2025
f5278d3
Revert "tuncation/transformation"
dweindl Apr 23, 2025
728b4d6
review
dweindl Apr 23, 2025
4387443
Merge branch 'develop' into 330_truncated
dweindl Apr 23, 2025
e3d2eba
review
dweindl Apr 23, 2025
3775bb4
_bounds_truncate
dweindl Apr 23, 2025
48c9bdf
Update tests/v1/test_priors.py
dweindl Apr 24, 2025
4d19fc3
..
dweindl Apr 24, 2025
d2d9202
..
dweindl Apr 24, 2025
b0b5e7e
Merge branch 'develop' into 330_truncated
dweindl Apr 24, 2025
a64083e
..
dweindl Apr 24, 2025
330f902
trunc scale
dweindl Apr 24, 2025
5581a6e
pdf nan outside domain
dweindl Apr 24, 2025
4db0fe6
_pdf_untruncated
dweindl Apr 24, 2025
0a63f99
Merge branch 'develop' into 330_truncated
dweindl Apr 24, 2025
9ad05c1
prettify repr
dweindl Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
nb_execution_mode = "force"
nb_execution_raise_on_error = True
nb_execution_show_tb = True
nb_execution_timeout = 90 # max. seconds/cell

source_suffix = {
".rst": "restructuredtext",
Expand Down
28 changes: 19 additions & 9 deletions doc/example/distributions.ipynb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to some v1 subfolder? Now or later is fine. But I think priors will change a lot in v2

Copy link
Member Author

@dweindl dweindl Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about moving it to https://github.com/PEtab-dev/PEtab/ at some point. It might also be helpful for non-python petab users.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
},
{
"metadata": {
"collapsed": true
"collapsed": true,
"jupyter": {
"is_executing": true
}
},
"cell_type": "code",
"source": [
Expand All @@ -42,7 +45,7 @@
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
"\n",
" sample = prior.sample(10000)\n",
" sample = prior.sample(20_000)\n",
"\n",
" # pdf\n",
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n",
Expand Down Expand Up @@ -138,11 +141,13 @@
"metadata": {},
"cell_type": "code",
"source": [
"# different, because transformation!=LIN\n",
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n",
"\n",
"# same, because transformation=LIN\n",
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n"
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))"
],
"id": "5ca940bc24312fc6",
"outputs": [],
Expand All @@ -151,15 +156,18 @@
{
"metadata": {},
"cell_type": "markdown",
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):",
"source": "The given distributions are truncated at the bounds defined in the parameter table:",
"id": "b1a8b17d765db826"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias"
"plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n",
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n",
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n",
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))\n"
],
"id": "4ac42b1eed759bdd",
"outputs": [],
Expand All @@ -175,9 +183,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**9, 10**14), transformation=\"log10\"))\n",
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n",
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))"
],
"id": "581e1ac431860419",
"outputs": [],
Expand Down
181 changes: 157 additions & 24 deletions petab/v1/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,65 @@ class Distribution(abc.ABC):
If a float, the distribution is transformed to its corresponding
log distribution with the given base (e.g., Normal -> Log10Normal).
If ``False``, no transformation is applied.
:param trunc: The truncation points (lower, upper) of the distribution
or ``None`` if the distribution is not truncated.
"""

def __init__(self, log: bool | float = False):
def __init__(
self, *, log: bool | float = False, trunc: tuple[float, float] = None
):
if log is True:
log = np.exp(1)

if trunc == (-np.inf, np.inf):
trunc = None

if trunc is not None and trunc[0] > trunc[1]:
raise ValueError(
"The lower truncation limit must be smaller "
"than the upper truncation limit."
)

self._logbase = log
self._trunc = trunc

self._cd_low = None
self._cd_high = None
self._truncation_normalizer = 1

if self._trunc is not None:
try:
# the cumulative density of the transformed distribution at the
# truncation limits
self._cd_low = self._cdf_transformed_untruncated(
self.trunc_low
)
self._cd_high = self._cdf_transformed_untruncated(
self.trunc_high
)
# normalization factor for the PDF of the transformed
# distribution to account for truncation
self._truncation_normalizer = 1 / (
self._cd_high - self._cd_low
)
except NotImplementedError:
pass

@property
def trunc_low(self) -> float:
    """The lower truncation limit of the transformed distribution.

    :return: The first entry of the truncation limits, or ``-inf`` if the
        distribution is not truncated.
    """
    # Test against None explicitly (as the constructor does) instead of
    # relying on the truthiness of the container holding the limits.
    if self._trunc is None:
        return -np.inf
    return self._trunc[0]

@property
def trunc_high(self) -> float:
    """The upper truncation limit of the transformed distribution.

    :return: The second entry of the truncation limits, or ``inf`` if the
        distribution is not truncated.
    """
    # Test against None explicitly (as the constructor does) instead of
    # relying on the truthiness of the container holding the limits.
    if self._trunc is None:
        return np.inf
    return self._trunc[1]

def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
"""Undo the log transformation.
def _exp(self, x: np.ndarray | float) -> np.ndarray | float:
"""Exponentiate / undo the log transformation.

Exponentiate if a log transformation is applied to the distribution.
Otherwise, return the input.

:param x: The sample to transform.
:return: The transformed sample
Expand All @@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
return x
return self._logbase**x

def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float:
def _log(self, x: np.ndarray | float) -> np.ndarray | float:
"""Apply the log transformation.

Compute the log of x with the specified base if a log transformation
is applied to the distribution. Otherwise, return the input.

:param x: The value to transform.
:return: The transformed value.
"""
Expand All @@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray:
:param shape: The shape of the sample.
:return: A sample from the distribution.
"""
sample = self._sample(shape)
return self._undo_log(sample)
sample = (
self._exp(self._sample(shape))
if self._trunc is None
else self._inverse_transform_sample(shape)
)

return sample

@abc.abstractmethod
def _sample(self, shape=None) -> np.ndarray:
"""Sample from the underlying distribution.
"""Sample from the underlying distribution, accounting for truncation.

:param shape: The shape of the sample.
:return: A sample from the underlying distribution,
Expand All @@ -85,7 +143,11 @@ def pdf(self, x):
chain_rule_factor = (
(1 / (x * np.log(self._logbase))) if self._logbase else 1
)
return self._pdf(self._apply_log(x)) * chain_rule_factor
return (
self._pdf(self._log(x))
* chain_rule_factor
* self._truncation_normalizer
)

@abc.abstractmethod
def _pdf(self, x):
Expand All @@ -104,13 +166,71 @@ def logbase(self) -> bool | float:
"""
return self._logbase

def cdf(self, x):
    """Cumulative distribution function at x.

    Evaluates the CDF of the transformed distribution. If the distribution
    is truncated, the untruncated CDF is shifted by its value at the lower
    truncation limit and rescaled by the truncation normalizer, so that the
    result is 0 at the lower and 1 at the upper truncation limit.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    """
    cd_untruncated = self._cdf_transformed_untruncated(x)
    if self._trunc is None:
        # Without truncation, ``_cd_low`` is None and must not be subtracted.
        return cd_untruncated
    # Clip so that values outside the truncation limits map to 0 or 1
    # instead of negative values or values above 1.
    return np.clip(
        (cd_untruncated - self._cd_low) * self._truncation_normalizer, 0, 1
    )

def _cdf_transformed_untruncated(self, x):
    """CDF of the transformed distribution at x, ignoring truncation.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    """
    # Map x to the scale of the underlying distribution first, then
    # evaluate the CDF of the underlying distribution there.
    x_underlying = self._log(x)
    return self._cdf_untransformed_untruncated(x_underlying)

def _cdf_untransformed_untruncated(self, x):
    """CDF of the underlying (untransformed, untruncated) distribution.

    This default implementation always raises; a distribution that provides
    a CDF overrides it.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    :raises NotImplementedError: Always, in this default implementation.
    """
    raise NotImplementedError()

def _ppf_untransformed_untruncated(self, q):
    """PPF of the underlying (untransformed, untruncated) distribution.

    This default implementation always raises; a distribution that provides
    a percent point function overrides it.

    :param q: The quantile at which to evaluate the PPF.
    :return: The value of the PPF at ``q``.
    :raises NotImplementedError: Always, in this default implementation.
    """
    raise NotImplementedError()

def _ppf_transformed_untruncated(self, q):
    """PPF of the transformed distribution at q, ignoring truncation.

    :param q: The quantile at which to evaluate the PPF.
    :return: The value of the PPF at ``q``.
    """
    # Evaluate the PPF of the underlying distribution, then map the result
    # to the scale of the transformed distribution.
    value_underlying = self._ppf_untransformed_untruncated(q)
    return self._exp(value_underlying)

def _inverse_transform_sample(self, shape):
    """Draw an inverse transform sample from the transformed and truncated
    distribution.

    :param shape: The shape of the sample.
    :return: The sample.
    """
    # Draw quantiles uniformly between the cumulative densities at the
    # truncation limits, then map them through the untruncated PPF.
    quantiles = np.random.uniform(self._cd_low, self._cd_high, shape)
    return self._ppf_transformed_untruncated(quantiles)


class Normal(Distribution):
"""A (log-)normal distribution.

:param loc: The location parameter of the distribution.
:param scale: The scale parameter of the distribution.
:param truncation: The truncation limits of the distribution.
:param trunc: The truncation limits of the distribution.
``None`` if the distribution is not truncated. The truncation limits
are the truncation limits of the transformed distribution.
:param log: If ``True``, the distribution is transformed to a log-normal
distribution. If a float, the distribution is transformed to a
log-normal distribution with the given base.
Expand All @@ -124,19 +244,15 @@ def __init__(
self,
loc: float,
scale: float,
truncation: tuple[float, float] | None = None,
trunc: tuple[float, float] | None = None,
log: bool | float = False,
):
super().__init__(log=log)
self._loc = loc
self._scale = scale
self._truncation = truncation

if truncation is not None:
raise NotImplementedError("Truncation is not yet implemented.")
super().__init__(log=log, trunc=trunc)

def __repr__(self):
trunc = f", truncation={self._truncation}" if self._truncation else ""
trunc = f", trunc={self._trunc}" if self._trunc else ""
log = f", log={self._logbase}" if self._logbase else ""
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})"

Expand All @@ -146,6 +262,12 @@ def _sample(self, shape=None):
def _pdf(self, x):
return norm.pdf(x, loc=self._loc, scale=self._scale)

def _cdf_untransformed_untruncated(self, x):
    """CDF of the underlying normal distribution at x.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    """
    location, spread = self._loc, self._scale
    return norm.cdf(x, loc=location, scale=spread)

def _ppf_untransformed_untruncated(self, q):
    """PPF of the underlying normal distribution at q.

    :param q: The quantile at which to evaluate the PPF.
    :return: The value of the PPF at ``q``.
    """
    location, spread = self._loc, self._scale
    return norm.ppf(q, loc=location, scale=spread)

@property
def loc(self):
"""The location parameter of the underlying distribution."""
Expand Down Expand Up @@ -177,9 +299,9 @@ def __init__(
*,
log: bool | float = False,
):
super().__init__(log=log)
self._low = low
self._high = high
super().__init__(log=log)

def __repr__(self):
log = f", log={self._logbase}" if self._logbase else ""
Expand All @@ -191,13 +313,21 @@ def _sample(self, shape=None):
def _pdf(self, x):
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)

def _cdf_untransformed_untruncated(self, x):
    """CDF of the underlying uniform distribution at x.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    """
    # scipy parameterizes the uniform distribution by its lower limit
    # (loc) and its width (scale)
    width = self._high - self._low
    return uniform.cdf(x, loc=self._low, scale=width)

def _ppf_untransformed_untruncated(self, q):
    """PPF of the underlying uniform distribution at q.

    :param q: The quantile at which to evaluate the PPF.
    :return: The value of the PPF at ``q``.
    """
    # scipy parameterizes the uniform distribution by its lower limit
    # (loc) and its width (scale)
    width = self._high - self._low
    return uniform.ppf(q, loc=self._low, scale=width)


class Laplace(Distribution):
"""A (log-)Laplace distribution.

:param loc: The location parameter of the distribution.
:param scale: The scale parameter of the distribution.
:param truncation: The truncation limits of the distribution.
:param trunc: The truncation limits of the distribution.
``None`` if the distribution is not truncated. The truncation limits
are the truncation limits of the transformed distribution.
:param log: If ``True``, the distribution is transformed to a log-Laplace
distribution. If a float, the distribution is transformed to a
log-Laplace distribution with the given base.
Expand All @@ -211,18 +341,15 @@ def __init__(
self,
loc: float,
scale: float,
truncation: tuple[float, float] | None = None,
trunc: tuple[float, float] | None = None,
log: bool | float = False,
):
super().__init__(log=log)
self._loc = loc
self._scale = scale
self._truncation = truncation
if truncation is not None:
raise NotImplementedError("Truncation is not yet implemented.")
super().__init__(log=log, trunc=trunc)

def __repr__(self):
trunc = f", truncation={self._truncation}" if self._truncation else ""
trunc = f", trunc={self._trunc}" if self._trunc else ""
log = f", log={self._logbase}" if self._logbase else ""
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"

Expand All @@ -232,6 +359,12 @@ def _sample(self, shape=None):
def _pdf(self, x):
return laplace.pdf(x, loc=self._loc, scale=self._scale)

def _cdf_untransformed_untruncated(self, x):
    """CDF of the underlying Laplace distribution at x.

    :param x: The value at which to evaluate the CDF.
    :return: The value of the CDF at ``x``.
    """
    location, spread = self._loc, self._scale
    return laplace.cdf(x, loc=location, scale=spread)

def _ppf_untransformed_untruncated(self, q):
    """PPF of the underlying Laplace distribution at q.

    :param q: The quantile at which to evaluate the PPF.
    :return: The value of the PPF at ``q``.
    """
    location, spread = self._loc, self._scale
    return laplace.ppf(q, loc=location, scale=spread)

@property
def loc(self):
"""The location parameter of the underlying distribution."""
Expand Down
Loading
Loading