Skip to content

Commit 7ecc231

Browse files
committed
apply suggestions
1 parent 047dc46 commit 7ecc231

File tree

2 files changed

+77
-31
lines changed

2 files changed

+77
-31
lines changed

petab/v1/distributions.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,27 @@ class Distribution(abc.ABC):
3333
def __init__(self, log: bool | float = False):
3434
if log is True:
3535
log = np.exp(1)
36-
self._log = log
36+
self._logbase = log
3737

3838
def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
3939
"""Undo the log transformation.
4040
4141
:param x: The sample to transform.
4242
:return: The transformed sample
4343
"""
44-
if self._log is False:
44+
if self._logbase is False:
4545
return x
46-
return self._log**x
46+
return self._logbase**x
4747

4848
def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float:
4949
"""Apply the log transformation.
5050
5151
:param x: The value to transform.
5252
:return: The transformed value.
5353
"""
54-
if self._log is False:
54+
if self._logbase is False:
5555
return x
56-
return np.log(x) / np.log(self._log)
56+
return np.log(x) / np.log(self._logbase)
5757

5858
def sample(self, shape=None) -> np.ndarray:
5959
"""Sample from the distribution.
@@ -82,7 +82,9 @@ def pdf(self, x):
8282
"""
8383
# handle the log transformation; see also:
8484
# https://en.wikipedia.org/wiki/Probability_density_function#Scalar_to_scalar
85-
chain_rule_factor = (1 / (x * np.log(self._log))) if self._log else 1
85+
chain_rule_factor = (
86+
(1 / (x * np.log(self._logbase))) if self._logbase else 1
87+
)
8688
return self._pdf(self._apply_log(x)) * chain_rule_factor
8789

8890
@abc.abstractmethod
@@ -94,6 +96,14 @@ def _pdf(self, x):
9496
"""
9597
...
9698

99+
@property
100+
def logbase(self) -> bool | float:
101+
"""The base of the log transformation.
102+
103+
If ``False``, no transformation is applied.
104+
"""
105+
return self._logbase
106+
97107

98108
class Normal(Distribution):
99109
"""A (log-)normal distribution.
@@ -127,7 +137,7 @@ def __init__(
127137

128138
def __repr__(self):
129139
trunc = f", truncation={self._truncation}" if self._truncation else ""
130-
log = f", log={self._log}" if self._log else ""
140+
log = f", log={self._logbase}" if self._logbase else ""
131141
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})"
132142

133143
def _sample(self, shape=None):
@@ -172,7 +182,7 @@ def __init__(
172182
self._high = high
173183

174184
def __repr__(self):
175-
log = f", log={self._log}" if self._log else ""
185+
log = f", log={self._logbase}" if self._logbase else ""
176186
return f"Uniform(low={self._low}, high={self._high}{log})"
177187

178188
def _sample(self, shape=None):
@@ -213,7 +223,7 @@ def __init__(
213223

214224
def __repr__(self):
215225
trunc = f", truncation={self._truncation}" if self._truncation else ""
216-
log = f", log={self._log}" if self._log else ""
226+
log = f", log={self._logbase}" if self._logbase else ""
217227
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"
218228

219229
def _sample(self, shape=None):

tests/v1/test_distributions.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,42 @@
1-
from itertools import product
2-
31
import numpy as np
42
import pytest
3+
from numpy.testing import assert_allclose
54
from scipy.integrate import cumulative_trapezoid
6-
from scipy.stats import kstest
5+
from scipy.stats import (
6+
kstest,
7+
laplace,
8+
loglaplace,
9+
lognorm,
10+
loguniform,
11+
norm,
12+
uniform,
13+
)
714

815
from petab.v1.distributions import *
916
from petab.v2.C import *
1017

1118

1219
@pytest.mark.parametrize(
13-
"distribution, transform",
14-
list(
15-
product(
16-
[
17-
Normal(2, 1),
18-
Normal(2, 1, log=True),
19-
Normal(2, 1, log=10),
20-
Uniform(2, 4),
21-
Uniform(-2, 4, log=True),
22-
Uniform(2, 4, log=10),
23-
Laplace(1, 2),
24-
Laplace(1, 0.5, log=True),
25-
],
26-
[LIN, LOG, LOG10],
27-
)
28-
),
20+
"distribution",
21+
[
22+
Normal(2, 1),
23+
Normal(2, 1, log=True),
24+
Normal(2, 1, log=10),
25+
Uniform(2, 4),
26+
Uniform(-2, 4, log=True),
27+
Uniform(2, 4, log=10),
28+
Laplace(1, 2),
29+
Laplace(1, 0.5, log=True),
30+
],
2931
)
30-
def test_sample_matches_pdf(distribution, transform):
32+
def test_sample_matches_pdf(distribution):
3133
"""Test that the sample matches the PDF."""
3234
np.random.seed(1)
3335
N_SAMPLES = 10_000
34-
distribution.transform = transform
3536
sample = distribution.sample(N_SAMPLES)
3637

37-
# pdf -> cdf
3838
def cdf(x):
39+
# pdf -> cdf
3940
return cumulative_trapezoid(distribution.pdf(x), x)
4041

4142
# Kolmogorov-Smirnov test to check if the sample is drawn from the CDF
@@ -49,3 +50,38 @@ def cdf(x):
4950
# plt.show()
5051

5152
assert p > 0.05, (p, distribution)
53+
54+
# Test samples match scipy CDFs
55+
reference_pdf = None
56+
if isinstance(distribution, Normal) and distribution.logbase is False:
57+
reference_pdf = norm.pdf(sample, distribution.loc, distribution.scale)
58+
elif isinstance(distribution, Uniform) and distribution.logbase is False:
59+
reference_pdf = uniform.pdf(
60+
sample, distribution._low, distribution._high - distribution._low
61+
)
62+
elif isinstance(distribution, Laplace) and distribution.logbase is False:
63+
reference_pdf = laplace.pdf(
64+
sample, distribution.loc, distribution.scale
65+
)
66+
elif isinstance(distribution, Normal) and distribution.logbase == np.exp(
67+
1
68+
):
69+
reference_pdf = lognorm.pdf(
70+
sample, scale=np.exp(distribution.loc), s=distribution.scale
71+
)
72+
elif isinstance(distribution, Uniform) and distribution.logbase == np.exp(
73+
1
74+
):
75+
reference_pdf = loguniform.pdf(
76+
sample, np.exp(distribution._low), np.exp(distribution._high)
77+
)
78+
elif isinstance(distribution, Laplace) and distribution.logbase == np.exp(
79+
1
80+
):
81+
reference_pdf = loglaplace.pdf(
82+
sample, c=1 / distribution.scale, scale=np.exp(distribution.loc)
83+
)
84+
if reference_pdf is not None:
85+
assert_allclose(
86+
distribution.pdf(sample), reference_pdf, rtol=1e-10, atol=1e-14
87+
)

0 commit comments

Comments
 (0)