Skip to content

Commit 0b704ba

Browse files
committed
..
1 parent 6420d67 commit 0b704ba

File tree

4 files changed

+99
-168
lines changed

4 files changed

+99
-168
lines changed

doc/example/distributions.ipynb

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,11 @@
177177
"source": [
178178
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
179179
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
180-
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
181-
"\n"
180+
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
182181
],
183182
"id": "581e1ac431860419",
184183
"outputs": [],
185184
"execution_count": null
186-
},
187-
{
188-
"metadata": {},
189-
"cell_type": "code",
190-
"source": "",
191-
"id": "633733651bbc3ef0",
192-
"outputs": [],
193-
"execution_count": null
194185
}
195186
],
196187
"metadata": {

petab/v1/distributions.py

Lines changed: 86 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,76 @@
44
import abc
55

66
import numpy as np
7-
from scipy.stats import laplace, lognorm, loguniform, norm, uniform
7+
from scipy.stats import laplace, norm, uniform
88

99
__all__ = [
1010
"Distribution",
1111
"Normal",
12-
"LogNormal",
1312
"Uniform",
14-
"LogUniform",
1513
"Laplace",
16-
"LogLaplace",
1714
]
1815

1916

2017
class Distribution(abc.ABC):
2118
"""A univariate probability distribution."""
2219

23-
@abc.abstractmethod
20+
def __init__(self, log: bool | float = False):
21+
if log is True:
22+
log = np.exp(1)
23+
self._log = log
24+
25+
def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
26+
"""Undo the log transformation.
27+
28+
:param x: The sample to transform.
29+
:return: The transformed sample
30+
"""
31+
if self._log is False:
32+
return x
33+
return self._log**x
34+
35+
def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float:
36+
"""Apply the log transformation.
37+
38+
:param x: The value to transform.
39+
:return: The transformed value.
40+
"""
41+
if self._log is False:
42+
return x
43+
return np.log(x) / np.log(self._log)
44+
2445
def sample(self, shape=None) -> np.ndarray:
2546
"""Sample from the distribution.
2647
2748
:param shape: The shape of the sample.
2849
:return: A sample from the distribution.
2950
"""
30-
...
51+
sample = self._sample(shape)
52+
return self._undo_log(sample)
3153

3254
@abc.abstractmethod
55+
def _sample(self, shape=None) -> np.ndarray:
56+
"""Sample from the underlying distribution.
57+
58+
:param shape: The shape of the sample.
59+
:return: A sample from the underlying distribution,
60+
before applying, e.g., the log transformation.
61+
"""
62+
...
63+
3364
def pdf(self, x):
3465
"""Probability density function at x.
3566
67+
:param x: The value at which to evaluate the PDF.
68+
:return: The value of the PDF at ``x``.
69+
"""
70+
chain_rule_factor = (1 / (x * np.log(self._log))) if self._log else 1
71+
return self._pdf(self._apply_log(x)) * chain_rule_factor
72+
73+
@abc.abstractmethod
74+
def _pdf(self, x):
75+
"""Probability density function of the underlying distribution at x.
76+
3677
:param x: The value at which to evaluate the PDF.
3778
:return: The value of the PDF at ``x``.
3879
"""
@@ -44,71 +85,39 @@ class Normal(Distribution):
4485

4586
def __init__(
4687
self,
47-
mean: float,
48-
std: float,
88+
loc: float,
89+
scale: float,
4990
truncation: tuple[float, float] | None = None,
91+
log: bool | float = False,
5092
):
51-
super().__init__()
52-
self._mean = mean
53-
self._std = std
93+
super().__init__(log=log)
94+
self._loc = loc
95+
self._scale = scale
5496
self._truncation = truncation
5597

5698
if truncation is not None:
5799
raise NotImplementedError("Truncation is not yet implemented.")
58100

59101
def __repr__(self):
60-
return (
61-
f"Normal(mean={self._mean}, std={self._std}, "
62-
f"truncation={self._truncation})"
63-
)
64-
65-
def sample(self, shape=None):
66-
return np.random.normal(loc=self._mean, scale=self._std, size=shape)
67-
68-
def pdf(self, x):
69-
return norm.pdf(x, loc=self._mean, scale=self._std)
70-
71-
72-
class LogNormal(Distribution):
73-
"""A log-normal distribution.
74-
75-
:param mean: The mean of the underlying normal distribution.
76-
:param std: The standard deviation of the underlying normal distribution.
77-
78-
"""
79-
80-
def __init__(
81-
self,
82-
mean: float,
83-
std: float,
84-
truncation: tuple[float, float] | None = None,
85-
base: float = np.exp(1),
86-
):
87-
super().__init__()
88-
self._mean = mean
89-
self._std = std
90-
self._truncation = truncation
91-
self._base = base
102+
trunc = f", truncation={self._truncation}" if self._truncation else ""
103+
log = f", log={self._log}" if self._log else ""
104+
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})"
92105

93-
if truncation is not None:
94-
raise NotImplementedError("Truncation is not yet implemented.")
106+
def _sample(self, shape=None):
107+
return np.random.normal(loc=self._loc, scale=self._scale, size=shape)
95108

96-
if base != np.exp(1):
97-
raise NotImplementedError("Only base e is supported.")
109+
def _pdf(self, x):
110+
return norm.pdf(x, loc=self._loc, scale=self._scale)
98111

99-
def __repr__(self):
100-
return (
101-
f"LogNormal(mean={self._mean}, std={self._std}, "
102-
f"base={self._base}, truncation={self._truncation})"
103-
)
104-
105-
def sample(self, shape=None):
106-
return np.random.lognormal(
107-
mean=self._mean, sigma=self._std, size=shape
108-
)
112+
@property
113+
def loc(self):
114+
"""The location parameter of the underlying distribution."""
115+
return self._loc
109116

110-
def pdf(self, x):
111-
return lognorm.pdf(x, scale=np.exp(self._mean), s=self._std)
117+
@property
118+
def scale(self):
119+
"""The scale parameter of the underlying distribution."""
120+
return self._scale
112121

113122

114123
class Uniform(Distribution):
@@ -118,58 +127,24 @@ def __init__(
118127
self,
119128
low: float,
120129
high: float,
130+
*,
131+
log: bool | float = False,
121132
):
122-
super().__init__()
133+
super().__init__(log=log)
123134
self._low = low
124135
self._high = high
125136

126137
def __repr__(self):
127-
return f"Uniform(low={self._low}, high={self._high})"
138+
log = f", log={self._log}" if self._log else ""
139+
return f"Uniform(low={self._low}, high={self._high}{log})"
128140

129-
def sample(self, shape=None):
141+
def _sample(self, shape=None):
130142
return np.random.uniform(low=self._low, high=self._high, size=shape)
131143

132-
def pdf(self, x):
144+
def _pdf(self, x):
133145
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
134146

135147

136-
class LogUniform(Distribution):
137-
"""A log-uniform distribution.
138-
139-
:param low: The lower bound of the underlying normal distribution.
140-
:param high: The upper bound of the underlying normal distribution.
141-
"""
142-
143-
def __init__(
144-
self,
145-
low: float,
146-
high: float,
147-
base: float = np.exp(1),
148-
):
149-
super().__init__()
150-
self._low = low
151-
self._high = high
152-
self._base = base
153-
# re-scaled distribution parameters as required by
154-
# scipy.stats.loguniform
155-
self._low_internal = np.exp(np.log(base) * low)
156-
self._high_internal = np.exp(np.log(base) * high)
157-
158-
def __repr__(self):
159-
return (
160-
f"LogUniform(low={self._low}, high={self._high}, "
161-
f"base={self._base})"
162-
)
163-
164-
def sample(self, shape=None):
165-
return loguniform.rvs(
166-
self._low_internal, self._high_internal, size=shape
167-
)
168-
169-
def pdf(self, x):
170-
return loguniform.pdf(x, self._low_internal, self._high_internal)
171-
172-
173148
class Laplace(Distribution):
174149
"""A Laplace distribution."""
175150

@@ -178,65 +153,32 @@ def __init__(
178153
loc: float,
179154
scale: float,
180155
truncation: tuple[float, float] | None = None,
156+
log: bool | float = False,
181157
):
182-
super().__init__()
158+
super().__init__(log=log)
183159
self._loc = loc
184160
self._scale = scale
185161
self._truncation = truncation
186162
if truncation is not None:
187163
raise NotImplementedError("Truncation is not yet implemented.")
188164

189-
def sample(self, shape=None):
165+
def __repr__(self):
166+
trunc = f", truncation={self._truncation}" if self._truncation else ""
167+
log = f", log={self._log}" if self._log else ""
168+
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"
169+
170+
def _sample(self, shape=None):
190171
return np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
191172

192-
def pdf(self, x):
173+
def _pdf(self, x):
193174
return laplace.pdf(x, loc=self._loc, scale=self._scale)
194175

195-
196-
class LogLaplace(Distribution):
197-
"""A log-Laplace distribution."""
198-
199-
def __init__(
200-
self,
201-
loc: float,
202-
scale: float,
203-
truncation: tuple[float, float] | None = None,
204-
base: float = np.exp(1),
205-
):
206-
super().__init__()
207-
self._loc = loc
208-
self._scale = scale
209-
self._truncation = truncation
210-
self._base = base
211-
if truncation is not None:
212-
raise NotImplementedError("Truncation is not yet implemented.")
213-
if base != np.exp(1):
214-
raise NotImplementedError("Only base e is supported.")
215-
216-
def __repr__(self):
217-
return (
218-
f"LogLaplace(loc={self._loc}, scale={self._scale}, "
219-
f"base={self._base}, truncation={self._truncation})"
220-
)
221-
222176
@property
223177
def loc(self):
224-
"""The mean of the underlying Laplace distribution."""
178+
"""The location parameter of the underlying distribution."""
225179
return self._loc
226180

227181
@property
228182
def scale(self):
229-
"""The scale of the underlying Laplace distribution."""
183+
"""The scale parameter of the underlying distribution."""
230184
return self._scale
231-
232-
def sample(self, shape=None):
233-
return np.exp(
234-
np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
235-
)
236-
237-
def pdf(self, x):
238-
return (
239-
1
240-
/ (2 * self.scale * x)
241-
* np.exp(-np.abs(np.log(x) - self._loc) / self._scale)
242-
)

petab/v1/priors.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,17 @@ def __init__(
9797
case (C.LAPLACE, _) | (C.PARAMETER_SCALE_LAPLACE, C.LIN):
9898
self.distribution = Laplace(*parameters)
9999
case (C.PARAMETER_SCALE_UNIFORM, C.LOG):
100-
self.distribution = LogUniform(*parameters)
100+
self.distribution = Uniform(*parameters, log=True)
101101
case (C.LOG_NORMAL, _) | (C.PARAMETER_SCALE_NORMAL, C.LOG):
102-
self.distribution = LogNormal(*parameters)
102+
self.distribution = Normal(*parameters, log=True)
103103
case (C.LOG_LAPLACE, _) | (C.PARAMETER_SCALE_LAPLACE, C.LOG):
104-
self.distribution = LogLaplace(*parameters)
104+
self.distribution = Laplace(*parameters, log=True)
105105
case (C.PARAMETER_SCALE_UNIFORM, C.LOG10):
106-
self.distribution = LogUniform(*parameters, base=10)
106+
self.distribution = Uniform(*parameters, log=10)
107107
case (C.PARAMETER_SCALE_NORMAL, C.LOG10):
108-
self.distribution = LogNormal(
109-
np.log(10) * parameters[0], np.log(10) * parameters[1]
110-
)
108+
self.distribution = Normal(*parameters, log=10)
111109
case (C.PARAMETER_SCALE_LAPLACE, C.LOG10):
112-
self.distribution = LogLaplace(
113-
np.log(10) * parameters[0], np.log(10) * parameters[1]
114-
)
110+
self.distribution = Laplace(*parameters, log=10)
115111
case _:
116112
raise ValueError(
117113
"Unsupported distribution type / transformation: "

tests/v1/test_distributions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
list(
1515
product(
1616
[
17-
Normal(1, 1),
18-
LogNormal(2, 1),
17+
Normal(2, 1),
18+
Normal(2, 1, log=True),
19+
Normal(2, 1, log=10),
1920
Uniform(2, 4),
20-
LogUniform(1, 2),
21+
Uniform(-2, 4, log=True),
22+
Uniform(2, 4, log=10),
2123
Laplace(1, 2),
22-
LogLaplace(1, 0.5),
24+
Laplace(1, 0.5, log=True),
2325
],
2426
[LIN, LOG, LOG10],
2527
)

0 commit comments

Comments
 (0)