Skip to content

Commit cc6a7e7

Browse files
committed
..
1 parent 6420d67 commit cc6a7e7

File tree

4 files changed

+150
-172
lines changed

4 files changed

+150
-172
lines changed

doc/example/distributions.ipynb

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,11 @@
177177
"source": [
178178
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
179179
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
180-
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
181-
"\n"
180+
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
182181
],
183182
"id": "581e1ac431860419",
184183
"outputs": [],
185184
"execution_count": null
186-
},
187-
{
188-
"metadata": {},
189-
"cell_type": "code",
190-
"source": "",
191-
"id": "633733651bbc3ef0",
192-
"outputs": [],
193-
"execution_count": null
194185
}
195186
],
196187
"metadata": {

petab/v1/distributions.py

Lines changed: 137 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -4,239 +4,228 @@
44
import abc
55

66
import numpy as np
7-
from scipy.stats import laplace, lognorm, loguniform, norm, uniform
7+
from scipy.stats import laplace, norm, uniform
88

99
__all__ = [
1010
"Distribution",
1111
"Normal",
12-
"LogNormal",
1312
"Uniform",
14-
"LogUniform",
1513
"Laplace",
16-
"LogLaplace",
1714
]
1815

1916

2017
class Distribution(abc.ABC):
21-
"""A univariate probability distribution."""
18+
"""A univariate probability distribution.
19+
20+
This class provides a common interface for sampling from and evaluating
21+
the probability density function of a univariate probability distribution.
22+
23+
The distribution can be transformed by applying a logarithm to the samples
24+
and the PDF. This is useful, e.g., for log-normal distributions.
25+
26+
:param log: If ``True``, the distribution is transformed to its
27+
corresponding log distribution (e.g., Normal -> LogNormal).
28+
If a float, the distribution is transformed to its corresponding
29+
log distribution with the given base (e.g., Normal -> Log10Normal).
30+
If ``False``, no transformation is applied.
31+
"""
32+
33+
def __init__(self, log: bool | float = False):
34+
if log is True:
35+
log = np.exp(1)
36+
self._log = log
37+
38+
def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
39+
"""Undo the log transformation.
40+
41+
:param x: The sample to transform.
42+
:return: The transformed sample
43+
"""
44+
if self._log is False:
45+
return x
46+
return self._log**x
47+
48+
def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float:
49+
"""Apply the log transformation.
50+
51+
:param x: The value to transform.
52+
:return: The transformed value.
53+
"""
54+
if self._log is False:
55+
return x
56+
return np.log(x) / np.log(self._log)
2257

23-
@abc.abstractmethod
2458
def sample(self, shape=None) -> np.ndarray:
2559
"""Sample from the distribution.
2660
2761
:param shape: The shape of the sample.
2862
:return: A sample from the distribution.
2963
"""
30-
...
64+
sample = self._sample(shape)
65+
return self._undo_log(sample)
3166

3267
@abc.abstractmethod
68+
def _sample(self, shape=None) -> np.ndarray:
69+
"""Sample from the underlying distribution.
70+
71+
:param shape: The shape of the sample.
72+
:return: A sample from the underlying distribution,
73+
before applying, e.g., the log transformation.
74+
"""
75+
...
76+
3377
def pdf(self, x):
3478
"""Probability density function at x.
3579
80+
:param x: The value at which to evaluate the PDF.
81+
:return: The value of the PDF at ``x``.
82+
"""
83+
chain_rule_factor = (1 / (x * np.log(self._log))) if self._log else 1
84+
return self._pdf(self._apply_log(x)) * chain_rule_factor
85+
86+
@abc.abstractmethod
87+
def _pdf(self, x):
88+
"""Probability density function of the underlying distribution at x.
89+
3690
:param x: The value at which to evaluate the PDF.
3791
:return: The value of the PDF at ``x``.
3892
"""
3993
...
4094

4195

4296
class Normal(Distribution):
43-
"""A normal distribution."""
97+
"""A (log-)normal distribution.
98+
99+
:param loc: The location parameter of the distribution.
100+
:param scale: The scale parameter of the distribution.
101+
:param truncation: The truncation limits of the distribution.
102+
:param log: If ``True``, the distribution is transformed to a log-normal
103+
distribution. If a float, the distribution is transformed to a
104+
log-normal distribution with the given base.
105+
If ``False``, no transformation is applied.
106+
If a transformation is applied, the location and scale parameters
107+
and the truncation limits are the location, scale and truncation limits
108+
of the underlying normal distribution.
109+
"""
44110

45111
def __init__(
46112
self,
47-
mean: float,
48-
std: float,
113+
loc: float,
114+
scale: float,
49115
truncation: tuple[float, float] | None = None,
116+
log: bool | float = False,
50117
):
51-
super().__init__()
52-
self._mean = mean
53-
self._std = std
118+
super().__init__(log=log)
119+
self._loc = loc
120+
self._scale = scale
54121
self._truncation = truncation
55122

56123
if truncation is not None:
57124
raise NotImplementedError("Truncation is not yet implemented.")
58125

59126
def __repr__(self):
60-
return (
61-
f"Normal(mean={self._mean}, std={self._std}, "
62-
f"truncation={self._truncation})"
63-
)
64-
65-
def sample(self, shape=None):
66-
return np.random.normal(loc=self._mean, scale=self._std, size=shape)
67-
68-
def pdf(self, x):
69-
return norm.pdf(x, loc=self._mean, scale=self._std)
70-
71-
72-
class LogNormal(Distribution):
73-
"""A log-normal distribution.
74-
75-
:param mean: The mean of the underlying normal distribution.
76-
:param std: The standard deviation of the underlying normal distribution.
77-
78-
"""
79-
80-
def __init__(
81-
self,
82-
mean: float,
83-
std: float,
84-
truncation: tuple[float, float] | None = None,
85-
base: float = np.exp(1),
86-
):
87-
super().__init__()
88-
self._mean = mean
89-
self._std = std
90-
self._truncation = truncation
91-
self._base = base
92-
93-
if truncation is not None:
94-
raise NotImplementedError("Truncation is not yet implemented.")
127+
trunc = f", truncation={self._truncation}" if self._truncation else ""
128+
log = f", log={self._log}" if self._log else ""
129+
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})"
95130

96-
if base != np.exp(1):
97-
raise NotImplementedError("Only base e is supported.")
131+
def _sample(self, shape=None):
132+
return np.random.normal(loc=self._loc, scale=self._scale, size=shape)
98133

99-
def __repr__(self):
100-
return (
101-
f"LogNormal(mean={self._mean}, std={self._std}, "
102-
f"base={self._base}, truncation={self._truncation})"
103-
)
134+
def _pdf(self, x):
135+
return norm.pdf(x, loc=self._loc, scale=self._scale)
104136

105-
def sample(self, shape=None):
106-
return np.random.lognormal(
107-
mean=self._mean, sigma=self._std, size=shape
108-
)
137+
@property
138+
def loc(self):
139+
"""The location parameter of the underlying distribution."""
140+
return self._loc
109141

110-
def pdf(self, x):
111-
return lognorm.pdf(x, scale=np.exp(self._mean), s=self._std)
142+
@property
143+
def scale(self):
144+
"""The scale parameter of the underlying distribution."""
145+
return self._scale
112146

113147

114148
class Uniform(Distribution):
115-
"""A uniform distribution."""
149+
"""A (log-)uniform distribution.
150+
151+
:param low: The lower bound of the distribution.
152+
:param high: The upper bound of the distribution.
153+
:param log: If ``True``, the distribution is transformed to a log-uniform
154+
distribution. If a float, the distribution is transformed to a
155+
log-uniform distribution with the given base.
156+
If ``False``, no transformation is applied.
157+
If a transformation is applied, the lower and upper bounds are the
158+
lower and upper bounds of the underlying uniform distribution.
159+
"""
116160

117161
def __init__(
118162
self,
119163
low: float,
120164
high: float,
165+
*,
166+
log: bool | float = False,
121167
):
122-
super().__init__()
168+
super().__init__(log=log)
123169
self._low = low
124170
self._high = high
125171

126172
def __repr__(self):
127-
return f"Uniform(low={self._low}, high={self._high})"
173+
log = f", log={self._log}" if self._log else ""
174+
return f"Uniform(low={self._low}, high={self._high}{log})"
128175

129-
def sample(self, shape=None):
176+
def _sample(self, shape=None):
130177
return np.random.uniform(low=self._low, high=self._high, size=shape)
131178

132-
def pdf(self, x):
179+
def _pdf(self, x):
133180
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
134181

135182

136-
class LogUniform(Distribution):
137-
"""A log-uniform distribution.
138-
139-
:param low: The lower bound of the underlying normal distribution.
140-
:param high: The upper bound of the underlying normal distribution.
141-
"""
142-
143-
def __init__(
144-
self,
145-
low: float,
146-
high: float,
147-
base: float = np.exp(1),
148-
):
149-
super().__init__()
150-
self._low = low
151-
self._high = high
152-
self._base = base
153-
# re-scaled distribution parameters as required by
154-
# scipy.stats.loguniform
155-
self._low_internal = np.exp(np.log(base) * low)
156-
self._high_internal = np.exp(np.log(base) * high)
157-
158-
def __repr__(self):
159-
return (
160-
f"LogUniform(low={self._low}, high={self._high}, "
161-
f"base={self._base})"
162-
)
163-
164-
def sample(self, shape=None):
165-
return loguniform.rvs(
166-
self._low_internal, self._high_internal, size=shape
167-
)
168-
169-
def pdf(self, x):
170-
return loguniform.pdf(x, self._low_internal, self._high_internal)
171-
172-
173183
class Laplace(Distribution):
174-
"""A Laplace distribution."""
184+
"""A (log-)Laplace distribution.
185+
186+
:param loc: The location parameter of the distribution.
187+
:param scale: The scale parameter of the distribution.
188+
:param truncation: The truncation limits of the distribution.
189+
:param log: If ``True``, the distribution is transformed to a log-Laplace
190+
distribution. If a float, the distribution is transformed to a
191+
log-Laplace distribution with the given base.
192+
If ``False``, no transformation is applied.
193+
If a transformation is applied, the location and scale parameters
194+
and the truncation limits are the location, scale and truncation limits
195+
of the underlying Laplace distribution.
196+
"""
175197

176198
def __init__(
177199
self,
178200
loc: float,
179201
scale: float,
180202
truncation: tuple[float, float] | None = None,
203+
log: bool | float = False,
181204
):
182-
super().__init__()
205+
super().__init__(log=log)
183206
self._loc = loc
184207
self._scale = scale
185208
self._truncation = truncation
186209
if truncation is not None:
187210
raise NotImplementedError("Truncation is not yet implemented.")
188211

189-
def sample(self, shape=None):
212+
def __repr__(self):
213+
trunc = f", truncation={self._truncation}" if self._truncation else ""
214+
log = f", log={self._log}" if self._log else ""
215+
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"
216+
217+
def _sample(self, shape=None):
190218
return np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
191219

192-
def pdf(self, x):
220+
def _pdf(self, x):
193221
return laplace.pdf(x, loc=self._loc, scale=self._scale)
194222

195-
196-
class LogLaplace(Distribution):
197-
"""A log-Laplace distribution."""
198-
199-
def __init__(
200-
self,
201-
loc: float,
202-
scale: float,
203-
truncation: tuple[float, float] | None = None,
204-
base: float = np.exp(1),
205-
):
206-
super().__init__()
207-
self._loc = loc
208-
self._scale = scale
209-
self._truncation = truncation
210-
self._base = base
211-
if truncation is not None:
212-
raise NotImplementedError("Truncation is not yet implemented.")
213-
if base != np.exp(1):
214-
raise NotImplementedError("Only base e is supported.")
215-
216-
def __repr__(self):
217-
return (
218-
f"LogLaplace(loc={self._loc}, scale={self._scale}, "
219-
f"base={self._base}, truncation={self._truncation})"
220-
)
221-
222223
@property
223224
def loc(self):
224-
"""The mean of the underlying Laplace distribution."""
225+
"""The location parameter of the underlying distribution."""
225226
return self._loc
226227

227228
@property
228229
def scale(self):
229-
"""The scale of the underlying Laplace distribution."""
230+
"""The scale parameter of the underlying distribution."""
230231
return self._scale
231-
232-
def sample(self, shape=None):
233-
return np.exp(
234-
np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
235-
)
236-
237-
def pdf(self, x):
238-
return (
239-
1
240-
/ (2 * self.scale * x)
241-
* np.exp(-np.abs(np.log(x) - self._loc) / self._scale)
242-
)

0 commit comments

Comments
 (0)