Skip to content

Commit 2586b17

Browse files
Enable reparametrization of SciPy distributions (#56)
* Make fit an instance method and deprecate fit_instance
* Update documentation
* Make params_names an instance-level property
* Feedback from copilot
* Enable reparametrization of SciPy distributions
* Self is not available in 3.9
* Move reparametrization logic to ScipyDistribution
* Fix tests
* Fix module on wrapped distributions
* Revert unnecessary change
1 parent 0d912cb commit 2586b17

File tree

4 files changed

+175
-63
lines changed

4 files changed

+175
-63
lines changed

pykelihood/distributions/base.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from collections.abc import Sequence
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
7+
from typing import TYPE_CHECKING, Callable, Generic, Protocol, TypeVar
88

99
import numpy as np
1010
from numpy.typing import ArrayLike
@@ -129,7 +129,7 @@ def fit(
129129
The result of the fit. A new instance is created with the fitted parameters.
130130
"""
131131
init_parms = self._process_fit_params(**fixed_values)
132-
init = type(self)(**init_parms)
132+
init = self.with_params(**init_parms)
133133
data = init._apply_constraints(data)
134134

135135
if x0 is None:
@@ -236,16 +236,91 @@ def __getattr__(self, item: str):
236236
return getattr(self.fitted, item)
237237

238238

239+
class Reparametrization(Protocol):
240+
"""The blueprint for reparametrization functions.
241+
242+
A reparametrization function turns the parameters of the reparametrized
243+
distribution into the parameters of the base distribution.
244+
245+
The provided parameters have already been evaluated. Parameters may be
246+
renamed, modified, removed, or added to the returned dictionary to match
247+
the expected parameters of the base distribution.
248+
"""
249+
250+
def __call__(self, params: dict[str, ArrayLike]) -> dict[str, ArrayLike]: ...
251+
252+
253+
def no_reparametrization(params: dict[str, ArrayLike]) -> dict[str, ArrayLike]:
254+
"""
255+
A no-op reparametrization function that returns the parameters unchanged.
256+
This is useful when no reparametrization is needed.
257+
"""
258+
return params
259+
260+
261+
def _extract_scipy_shape_params_names(
262+
scipy_dist: stats.rv_continuous,
263+
) -> tuple[str, ...]:
264+
return tuple(scipy_dist.shapes.split(", ") if scipy_dist.shapes else ())
265+
266+
239267
class ScipyDistribution(Distribution, ABC):
240268
"""
241269
Base class for distributions based on SciPy.
242270
"""
243271

244272
_base_module: stats.rv_continuous
245273

246-
@abstractmethod
247-
def _to_scipy_args(self, **params) -> dict[str, ArrayLike]:
248-
raise NotImplementedError
274+
def __init__(
275+
self, *args, reparametrization: Reparametrization | None = None, **params
276+
):
277+
if args and reparametrization:
278+
raise ValueError(
279+
"Cannot use both positional parameters and reparametrization."
280+
)
281+
self.reparametrization = reparametrization or no_reparametrization
282+
if reparametrization is None:
283+
self._params_names = ("loc", "scale") + _extract_scipy_shape_params_names(
284+
self._base_module
285+
)
286+
287+
# Insert the positional arguments into params
288+
for arg, name in zip(args, self._params_names):
289+
if name not in params:
290+
params[name] = arg
291+
else:
292+
raise ValueError(
293+
f"Parameter `{name}` passed as positional and keyword argument when initializing {type(self).__name__} distribution."
294+
)
295+
296+
# Set default values for loc and scale if not provided
297+
params["loc"] = params.get("loc", 0.0)
298+
params["scale"] = params.get("scale", 1.0)
299+
300+
# Ensure all shape parameters are present
301+
for arg in self._params_names[len(args) :]:
302+
if arg not in params:
303+
raise ValueError(
304+
f"Missing shape parameter `{arg}` when initializing {type(self).__name__} distribution."
305+
)
306+
307+
# Reorder params to match the order of _params_names
308+
params = {a: params[a] for a in self._params_names}
309+
else:
310+
self._params_names = tuple(params)
311+
super().__init__(*params.values())
312+
313+
def _build_instance(self, **params) -> Self:
314+
return type(self)(reparametrization=self.reparametrization, **params)
315+
316+
@property
317+
def params_names(self) -> tuple[str, ...]:
318+
"""Return the names of the parameters."""
319+
return self._params_names
320+
321+
def _to_scipy_args(self, **kwargs):
322+
values = {k: kwargs.get(k, getattr(self, k)()) for k in self.params_names}
323+
return self.reparametrization(values)
249324

250325
def rvs(self, size=None, random_state=None, **kwargs):
251326
return self._base_module.rvs(

pykelihood/distributions/custom.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,20 @@ class GEV(ScipyDistribution):
231231

232232
_base_module = _stats.genextreme
233233

234+
def _reparametrization(self, params):
235+
return {
236+
"c": -params["shape"],
237+
"loc": params["loc"],
238+
"scale": params["scale"],
239+
}
240+
234241
def __init__(self, loc=0.0, scale=1.0, shape=0.0):
235-
super().__init__(loc, scale, shape)
242+
super().__init__(
243+
loc=loc, scale=scale, shape=shape, reparametrization=self._reparametrization
244+
)
245+
246+
def _build_instance(self, **params):
247+
return type(self)(**params)
236248

237249
@property
238250
def params_names(self):
@@ -284,32 +296,6 @@ def ub_shape(self, data):
284296
else:
285297
return self.scale / (x_max - self.loc())
286298

287-
def _to_scipy_args(self, loc=None, scale=None, shape=None):
288-
"""
289-
Convert to scipy arguments.
290-
291-
Parameters
292-
----------
293-
loc : float, optional
294-
Location parameter, by default None.
295-
scale : float, optional
296-
Scale parameter, by default None.
297-
shape : float, optional
298-
Shape parameter, by default None.
299-
300-
Returns
301-
-------
302-
dict
303-
Dictionary of scipy arguments.
304-
"""
305-
if shape is not None:
306-
shape = -shape
307-
return {
308-
"c": ifnone(shape, -self.shape()),
309-
"loc": ifnone(loc, self.loc()),
310-
"scale": ifnone(scale, self.scale()),
311-
}
312-
313299

314300
class GPD(ScipyDistribution):
315301
"""

pykelihood/distributions/scipy.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import scipy.special
3+
import scipy
44
from packaging.version import Version
55
from scipy import stats
66

@@ -13,9 +13,7 @@ def _name_from_scipy_dist(scipy_dist: stats.rv_continuous) -> str:
1313
return "".join(map(str.capitalize, scipy_dist_name.split("_")))
1414

1515

16-
def wrap_scipy_distribution(
17-
scipy_dist: stats.rv_continuous,
18-
) -> type[ScipyDistribution]:
16+
def wrap_scipy_distribution(scipy_dist: stats.rv_continuous) -> type[ScipyDistribution]:
1917
"""Wrap a scipy distribution class to create a ScipyDistribution subclass."""
2018
scipy_dist_name = type(scipy_dist).__name__.removesuffix("_gen")
2119
clean_dist_name = _name_from_scipy_dist(scipy_dist)
@@ -43,33 +41,15 @@ def format_param_docstring(param: str) -> str:
4341
for param in dist_params_names[2:]:
4442
docstring += format_param_docstring(param)
4543

46-
class Wrapper(ScipyDistribution):
47-
_base_module = scipy_dist
48-
__doc__ = docstring
49-
50-
def __init__(self, loc=0.0, scale=1.0, **kwargs):
51-
self._params_names = dist_params_names
52-
assert self._params_names[:2] == ("loc", "scale")
53-
shape_args = self._params_names[2:]
54-
for arg in shape_args:
55-
if arg not in kwargs:
56-
raise ValueError(
57-
f"Missing shape parameter `{arg}` when initializing {type(self).__name__} distribution."
58-
)
59-
args = [kwargs[a] for a in shape_args]
60-
super().__init__(loc, scale, *args)
61-
62-
@property
63-
def params_names(self) -> tuple[str, ...]:
64-
"""Return the names of the parameters."""
65-
return self._params_names
66-
67-
def _to_scipy_args(self, **kwargs):
68-
return {k: kwargs.get(k, getattr(self, k)()) for k in self.params_names}
69-
70-
Wrapper.__name__ = clean_dist_name
71-
Wrapper.__qualname__ = f"{Wrapper.__module__}.{Wrapper.__name__}"
72-
return Wrapper
44+
return type(
45+
clean_dist_name,
46+
(ScipyDistribution,),
47+
{
48+
"_base_module": scipy_dist,
49+
"__doc__": docstring,
50+
"__module__": wrap_scipy_distribution.__module__,
51+
},
52+
)
7353

7454

7555
Alpha = wrap_scipy_distribution(stats.alpha)

tests/test_parametrization.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import numpy.testing
3+
from numpy.typing import ArrayLike
4+
5+
from pykelihood import kernels
6+
from pykelihood.distributions import Uniform
7+
from pykelihood.metrics import log_likelihood
8+
9+
10+
def uniform_bounds(params: dict[str, ArrayLike]) -> dict[str, float]:
11+
return {"loc": params["a"], "scale": params["b"] - params["a"]}
12+
13+
14+
def test_reparametrization_names():
15+
u = Uniform(a=0, b=1, reparametrization=uniform_bounds)
16+
assert u.params_names == ("a", "b")
17+
18+
19+
def test_reparametrization_rvs():
20+
u = Uniform(a=3, b=5, reparametrization=uniform_bounds)
21+
sample = u.rvs(10000)
22+
numpy.testing.assert_array_less(sample, 5)
23+
numpy.testing.assert_array_less(3, sample)
24+
25+
26+
def test_reparametrization_pdf():
27+
u = Uniform(a=3, b=5, reparametrization=uniform_bounds)
28+
u_standard = Uniform(loc=3, scale=2)
29+
x = np.linspace(2, 6, 100)
30+
numpy.testing.assert_array_almost_equal(u.pdf(x), u_standard.pdf(x))
31+
32+
33+
def test_reparametrization_kernel():
34+
x = np.linspace(2, 6, 100)
35+
u = Uniform(a=2, b=kernels.linear(x, a=2, b=1), reparametrization=uniform_bounds)
36+
u.pdf(x)
37+
u.rvs(x.size)
38+
39+
40+
def test_reparametrization_with_params():
41+
u = Uniform(a=3, b=5, reparametrization=uniform_bounds)
42+
v = u.with_params(a=2, b=4)
43+
assert v.a() == 2
44+
assert v.b() == 4
45+
46+
47+
def test_reparametrization_fit():
48+
u = Uniform(a=0, b=5, reparametrization=uniform_bounds)
49+
data = np.random.uniform(1, 3, 1000)
50+
fit = u.fit(data)
51+
assert 0 < fit.a() <= 1.1
52+
assert 2.9 <= fit.b() < 5
53+
assert log_likelihood(fit, data) > log_likelihood(u, data)
54+
55+
56+
def test_reparametrization_fit_fixed_param():
57+
u = Uniform(a=0, b=5, reparametrization=uniform_bounds)
58+
data = np.random.uniform(1, 3, 1000)
59+
fit = u.fit(data, a=0.5)
60+
assert fit.a() == 0.5
61+
assert 2.9 <= fit.b() < 5
62+
assert log_likelihood(fit, data) > log_likelihood(u, data)
63+
64+
65+
def test_reparametrization_fit_kernel():
66+
u = Uniform(a=0, b=kernels.constant(value=5), reparametrization=uniform_bounds)
67+
data = np.random.uniform(1, 3, 1000)
68+
fit = u.fit(data, a=0.5)
69+
assert fit.a() == 0.5
70+
assert 2.9 <= fit.b() < 5
71+
assert log_likelihood(fit, data) > log_likelihood(u, data)

0 commit comments

Comments
 (0)