Skip to content

Commit 167d584

Browse files
Make params_names an instance-level property (#55)
* Make fit an instance method and deprecate fit_instance * Update documentation * Make params_names an instance-level property * Feedback from copilot
1 parent e40cf42 commit 167d584

File tree

5 files changed

+92
-48
lines changed

5 files changed

+92
-48
lines changed

pykelihood/distributions/custom.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ class Exponential(ScipyDistribution):
3030
Rate parameter, by default 1.0.
3131
"""
3232

33-
params_names = ("loc", "rate")
3433
_base_module = _stats.expon
3534

3635
def __init__(self, loc=0.0, rate=1.0):
3736
super().__init__(loc, rate)
3837

38+
@property
39+
def params_names(self):
40+
return ("loc", "rate")
41+
3942
def _to_scipy_args(self, loc=None, rate=None):
4043
"""
4144
Convert to scipy arguments.
@@ -71,12 +74,15 @@ class Gamma(ScipyDistribution):
7174
Shape parameter, by default 0.0.
7275
"""
7376

74-
params_names = ("loc", "scale", "shape")
7577
_base_module = _stats.gamma
7678

7779
def __init__(self, loc=0.0, scale=1.0, shape=0.0):
7880
super().__init__(loc, scale, shape)
7981

82+
@property
83+
def params_names(self):
84+
return ("loc", "scale", "shape")
85+
8086
def _to_scipy_args(self, loc=None, scale=None, shape=None):
8187
"""
8288
Convert to scipy arguments.
@@ -116,12 +122,15 @@ class Pareto(ScipyDistribution):
116122
Shape parameter, by default 1.0.
117123
"""
118124

119-
params_names = ("loc", "scale", "alpha")
120125
_base_module = _stats.pareto
121126

122127
def __init__(self, loc=0.0, scale=1.0, alpha=1.0):
123128
super().__init__(loc, scale, alpha)
124129

130+
@property
131+
def params_names(self):
132+
return ("loc", "scale", "alpha")
133+
125134
def _to_scipy_args(self, loc=None, scale=None, alpha=None):
126135
"""
127136
Convert to scipy arguments.
@@ -163,12 +172,15 @@ class Beta(ScipyDistribution):
163172
Beta parameter, by default 1.0.
164173
"""
165174

166-
params_names = ("loc", "scale", "alpha", "beta")
167175
_base_module = _stats.beta
168176

169177
def __init__(self, loc=0.0, scale=1.0, alpha=2.0, beta=1.0):
170178
super().__init__(loc, scale, alpha, beta)
171179

180+
@property
181+
def params_names(self):
182+
return ("loc", "scale", "alpha", "beta")
183+
172184
def _to_scipy_args(self, loc=None, scale=None, alpha=None, beta=None):
173185
"""
174186
Convert to scipy arguments.
@@ -217,12 +229,15 @@ class GEV(ScipyDistribution):
217229
is `-c`.
218230
"""
219231

220-
params_names = ("loc", "scale", "shape")
221232
_base_module = _stats.genextreme
222233

223234
def __init__(self, loc=0.0, scale=1.0, shape=0.0):
224235
super().__init__(loc, scale, shape)
225236

237+
@property
238+
def params_names(self):
239+
return ("loc", "scale", "shape")
240+
226241
def lb_shape(self, data):
227242
"""
228243
Calculate the lower bound of the shape parameter.
@@ -310,12 +325,15 @@ class GPD(ScipyDistribution):
310325
Shape parameter, by default 0.0.
311326
"""
312327

313-
params_names = ("loc", "scale", "shape")
314328
_base_module = _stats.genpareto
315329

316330
def __init__(self, loc=0.0, scale=1.0, shape=0.0):
317331
super().__init__(loc, scale, shape)
318332

333+
@property
334+
def params_names(self):
335+
return ("loc", "scale", "shape")
336+
319337
def _to_scipy_args(self, loc=None, scale=None, shape=None):
320338
"""
321339
Convert to scipy arguments.
@@ -360,8 +378,6 @@ class TruncatedDistribution(Distribution):
360378
If the lower and upper bounds are equal.
361379
"""
362380

363-
params_names = ("distribution",)
364-
365381
def __init__(
366382
self, distribution: Distribution, lower_bound=-np.inf, upper_bound=np.inf
367383
):
@@ -374,7 +390,11 @@ def __init__(
374390
upper_cdf = self.distribution.cdf(self.upper_bound)
375391
self._normalizer = upper_cdf - lower_cdf
376392

377-
def _build_instance(self, **new_params):
393+
@property
394+
def params_names(self):
395+
return ("distribution",)
396+
397+
def _build_instance(self, distribution: Distribution, **new_params):
378398
"""
379399
Build a new instance with the given parameters.
380400
@@ -388,7 +408,6 @@ def _build_instance(self, **new_params):
388408
TruncatedDistribution
389409
The new instance.
390410
"""
391-
distribution = new_params.pop("distribution")
392411
if new_params:
393412
raise ValueError(f"Unexpected arguments: {new_params}")
394413
return type(self)(distribution, self.lower_bound, self.upper_bound)

pykelihood/distributions/scipy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def format_param_docstring(param: str) -> str:
4545

4646
class Wrapper(ScipyDistribution):
4747
_base_module = scipy_dist
48-
params_names = dist_params_names
4948
__doc__ = docstring
5049

5150
def __init__(self, loc=0.0, scale=1.0, **kwargs):
52-
assert self.params_names[:2] == ("loc", "scale")
53-
shape_args = self.params_names[2:]
51+
self._params_names = dist_params_names
52+
assert self._params_names[:2] == ("loc", "scale")
53+
shape_args = self._params_names[2:]
5454
for arg in shape_args:
5555
if arg not in kwargs:
5656
raise ValueError(
@@ -59,6 +59,11 @@ def __init__(self, loc=0.0, scale=1.0, **kwargs):
5959
args = [kwargs[a] for a in shape_args]
6060
super().__init__(loc, scale, *args)
6161

62+
@property
63+
def params_names(self) -> tuple[str, ...]:
64+
"""Return the names of the parameters."""
65+
return self._params_names
66+
6267
def _to_scipy_args(self, **kwargs):
6368
return {k: kwargs.get(k, getattr(self, k)()) for k in self.params_names}
6469

pykelihood/kernels.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,6 @@ def wrapped(x, **param_values) -> Kernel:
178178
return wrapper
179179

180180

181-
"""
182-
Simple kernels with one covariate
183-
"""
184-
185-
186181
@kernel(a=0.0, b=0.0)
187182
def linear(X, a, b):
188183
r"""
@@ -290,12 +285,10 @@ def exponential_ratio(X, a, b, c):
290285
array-like
291286
Exponential ratio kernel output.
292287
"""
293-
inner = a * X
294-
inner = inner / b
295-
return c * np.exp(inner)
288+
return c * np.exp(a * X / b)
296289

297290

298-
@kernel(mu=0.0, sigma=1.0, scaling=0.0)
291+
@kernel(mu=0.0, sigma=1.0, scaling=1.0)
299292
def gaussian(X, mu, sigma, scaling):
300293
r"""
301294
Gaussian kernel function.

pykelihood/parameters.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import abc
34
from collections import ChainMap
45
from collections.abc import Iterable
56
from typing import TYPE_CHECKING, Any, Callable, TypeVar
@@ -37,13 +38,11 @@ def ensure_parametrized(x: Any, constant=False) -> Parametrized:
3738
return cls(x)
3839

3940

40-
class Parametrized:
41+
class Parametrized(abc.ABC):
4142
"""
4243
Base class for parametrized objects.
4344
"""
4445

45-
params_names: tuple[str, ...]
46-
4746
def __init__(self, *params: Parametrized | Any):
4847
"""
4948
Initialize the `Parametrized` object.
@@ -55,6 +54,19 @@ def __init__(self, *params: Parametrized | Any):
5554
"""
5655
self._params = tuple(ensure_parametrized(p) for p in params)
5756

57+
@property
58+
@abc.abstractmethod
59+
def params_names(self) -> tuple[str, ...]:
60+
"""
61+
Get the names of the parameters.
62+
63+
Returns
64+
-------
65+
tuple[str, ...]
66+
The parameter names.
67+
"""
68+
raise NotImplementedError()
69+
5870
def _build_instance(self, **new_params) -> Self:
5971
"""
6072
Build a new instance with the given parameters.
@@ -72,7 +84,7 @@ def _build_instance(self, **new_params) -> Self:
7284
return type(self)(**new_params)
7385

7486
@property
75-
def params(self) -> tuple[Parametrized]:
87+
def params(self) -> tuple[Parametrized, ...]:
7688
"""
7789
Get parameters in their parametrized format, e.g. how they were defined.
7890
@@ -96,7 +108,7 @@ def param_dict(self) -> dict[str, Parametrized]:
96108
return dict(zip(self.params_names, self.params))
97109

98110
@property
99-
def flattened_params(self) -> tuple[Parametrized]:
111+
def flattened_params(self) -> tuple[Parametrized, ...]:
100112
"""
101113
Get a horizontal view of all parameters in the final state of their
102114
respective tree of dependence.
@@ -220,14 +232,6 @@ def __call__(self, *args, **kwargs):
220232
raise NotImplementedError("A generic Parametrized object has no value!")
221233

222234
def __repr__(self):
223-
"""
224-
Get the string representation of the `Parametrized` object.
225-
226-
Returns
227-
-------
228-
str
229-
The string representation.
230-
"""
231235
args = [f"{a}={v!r}" for a, v in zip(self.params_names, self._params)]
232236
return f"{type(self).__name__}({', '.join(args)})"
233237

@@ -327,6 +331,10 @@ def __init__(self, value: npt.ArrayLike) -> None:
327331
"""
328332
self._value = np.asarray(value, dtype=np.float64)
329333

334+
@property
335+
def params_names(self) -> tuple[()]:
336+
return ()
337+
330338
@property
331339
def params(self):
332340
"""
@@ -497,10 +505,22 @@ def __init__(self, f: Callable, *, fname=None, **params: Parametrized):
497505
Parameters for the function.
498506
"""
499507
super().__init__(*params.values())
500-
self.params_names = tuple(params.keys())
508+
self._params_names = tuple(params.keys())
501509
self.f = f
502510
self.fname = fname or f.__qualname__
503511

512+
@property
513+
def params_names(self) -> tuple[str, ...]:
514+
"""
515+
Get the names of the parameters.
516+
517+
Returns
518+
-------
519+
tuple[str, ...]
520+
The parameter names.
521+
"""
522+
return self._params_names
523+
504524
def __call__(self, *args, **kwargs):
505525
"""
506526
Call the parametrized function with the given arguments.

tests/test_parameters.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import pytest
24

35
from pykelihood import parameters
@@ -13,14 +15,21 @@ def test_parameter_with_params():
1315
assert p.with_params([5.0])() == 5.0
1416

1517

18+
class BareParametrized(parameters.Parametrized):
19+
def __init__(self, *p: parameters.Parametrized | float, names: list[str]):
20+
super().__init__(*p)
21+
self._params_names = names
22+
23+
@property
24+
def params_names(self):
25+
return self._params_names
26+
27+
1628
def test_flattened_params():
1729
p1 = parameters.Parameter(1)
1830
p2 = parameters.ConstantParameter(2)
19-
a = parameters.Parametrized(p1, 2)
20-
a.params_names = ("x", "m")
21-
repr(a)
22-
b = parameters.Parametrized(a, p2)
23-
b.params_names = ("y", "n")
31+
a = BareParametrized(p1, 2, names=["x", "m"])
32+
b = BareParametrized(a, p2, names=["y", "n"])
2433
assert set(a.flattened_param_dict.keys()) == {"x", "m"}
2534
assert len(a.flattened_params) == 2
2635
assert set(b.flattened_param_dict.keys()) == {"y_x", "y_m", "n"}
@@ -30,15 +39,13 @@ def test_flattened_params():
3039
def test_flattened_params_with_embedded_constant():
3140
p1 = parameters.Parameter(1)
3241
p2 = parameters.ConstantParameter(2)
33-
a = parameters.Parametrized(p1, p2)
34-
a.params_names = ("x", "m")
35-
b = parameters.Parametrized(a)
36-
b.params_names = "y"
37-
assert set(a.flattened_param_dict.keys()) == {"x", "m"}
42+
a = BareParametrized(p1, p2, names=["x", "m"])
43+
b = BareParametrized(a, names=["y"])
44+
assert set(a.flattened_param_dict) == {"x", "m"}
3845
assert len(a.flattened_params) == 2
39-
assert set(b.flattened_param_dict.keys()) == {"y_x", "y_m"}
46+
assert set(b.flattened_param_dict) == {"y_x", "y_m"}
4047
assert len(b.flattened_params) == 2
41-
assert set(b.optimisation_param_dict.keys()) == {"y_x"}
48+
assert set(b.optimisation_param_dict) == {"y_x"}
4249
assert len(b.optimisation_params) == 1
4350

4451

0 commit comments

Comments
 (0)