Skip to content

Commit e40cf42

Browse files
Make fit an instance method and deprecate fit_instance (#53)
* Make fit an instance method and deprecate fit_instance * Update documentation
1 parent e6e7c19 commit e40cf42

File tree

12 files changed

+65
-115
lines changed

12 files changed

+65
-115
lines changed

docs/source/modules/distributions.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ some parameters have been adjusted (e.g., the GEV distribution) to align with st
2323

2424
~Distribution.cdf
2525
~Distribution.fit
26-
~Distribution.fit_instance
2726
~Distribution.inverse_cdf
2827
~Distribution.isf
2928
~Distribution.logcdf

docs/source/user_guide/fitting.rst

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@ Let's generate a larger sample from our previous object:
1414

1515
We can fit a ``Normal`` distribution to this data, which will return another ``Normal`` object:
1616

17-
>>> Normal.fit(data)
17+
>>> Normal().fit(data)
1818
Normal(loc=1.0250822420920338, scale=1.9376400770300832)
1919

2020
As you can see, the values are slightly different from the moments in the data.
2121
This is due to the fact that the ``fit`` method returns the Maximum Likelihood Estimator (MLE)
22-
for the data, and is thus the result of an optimisation (using **scipy.optimize**). Custom optimizer and arguments passed
23-
to ``scipy.optimize.minimize`` can be passed as ``kwargs`` to the ``fit`` method of any distribution.
22+
for the data, and is thus the result of an optimisation (using **scipy.optimize**).
2423

25-
The syntax ``distribution.fit(data, loc=0)`` can be used to fit the distribution to the data while keeping the ``loc``
26-
parameter null:
24+
The syntax ``distribution.fit(data, loc=1)`` can be used to fit the distribution to the data while keeping the ``loc``
25+
parameter at a fixed value, in this case 1:
2726

28-
>>> Normal.fit(data, loc=1)
27+
>>> Normal().fit(data, loc=1)
2928
Normal(loc=1.0, scale=1.9377929687500024)

docs/source/user_guide/penalty_fitting.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ apply a Lasso penalty.
1313
>>> def lassolike_score(distribution, data):
1414
... return -np.sum(distribution.logpdf(data)) + 5 * np.abs(distribution.loc())
1515
...
16-
>>> cond_fit = Normal.fit(data, score=lassolike_score)
16+
>>> cond_fit = Normal().fit(data, score=lassolike_score)
1717

1818
We then compare a fit using the standard negative log-likelihood function to the use of the Lasso-penalized likelihood.
1919

20-
>>> std_fit = Normal.fit(data)
20+
>>> std_fit = Normal().fit(data)
2121
>>> std_fit.loc.value
2222
-0.010891307380632494
2323
>>> cond_fit.loc.value

docs/source/user_guide/trend_fitting.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ array([-0.99802364, -0.99503679, -0.98900434, -0.98277981, -0.979487 ,
1212

1313
If we try to fit this without a trend, the resulting distribution will miss out on most of the information.
1414

15-
>>> Normal.fit(data)
15+
>>> Normal().fit(data)
1616
Normal(loc=-3.6462053656578005e-05, scale=0.5789668679237372)
1717

1818
Fitting a ``Normal`` distribution with a trend in the ``loc`` parameter can be done using the following piece of code:
1919

2020
>>> from pykelihood import kernels
21-
>>> Normal.fit(data, loc=kernels.linear(np.arange(365)))
21+
>>> Normal().fit(data, loc=kernels.linear(np.arange(365)))
2222
Normal(loc=linear(a=-1.0000458359290572, b=0.005494714384381866), scale=0.0010055323717468906)
2323

2424
The ``kernels`` module is flexible and can be adapted by users to support any kind of trend.

pykelihood/distributions/base.py

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from abc import ABC, abstractmethod
45
from collections.abc import Sequence
56
from dataclasses import dataclass
@@ -12,7 +13,7 @@
1213

1314
from pykelihood.generic_types import Obs
1415
from pykelihood.metrics import opposite_log_likelihood
15-
from pykelihood.parameters import ConstantParameter, Parametrized, ensure_parametrized
16+
from pykelihood.parameters import Parametrized, ensure_parametrized
1617

1718
if TYPE_CHECKING:
1819
from typing import Self
@@ -49,8 +50,6 @@ class Distribution(Parametrized, ABC):
4950
Inverse of the cumulative distribution function.
5051
fit(data: Obs, x0: Sequence[float] = None, score: Callable[["Distribution", Obs], float] = opposite_log_likelihood, scipy_args: Optional[Dict] = None, **fixed_values) -> SomeDistribution
5152
Fit the distribution to the data.
52-
fit_instance(data, score=opposite_log_likelihood, x0: Sequence[float] = None, scipy_args: Optional[Dict] = None, **fixed_values)
53-
Fit the instance to the data.
5453
"""
5554

5655
def __hash__(self):
@@ -101,15 +100,14 @@ def inverse_cdf(self, q: Obs):
101100
def _apply_constraints(self, data):
102101
return data
103102

104-
@classmethod
105103
def fit(
106-
cls: type[SomeDistribution],
104+
self,
107105
data: Obs,
108106
x0: Sequence[float] | None = None,
109107
score: Callable[[Distribution, Obs], float] = opposite_log_likelihood,
110108
scipy_args: dict | None = None,
111109
**fixed_values,
112-
) -> Fit[SomeDistribution]:
110+
) -> Fit[Self]:
113111
"""
114112
Fit the distribution to the data.
115113
@@ -128,25 +126,14 @@ def fit(
128126
129127
Returns
130128
-------
131-
The result of the fit
129+
The result of the fit. A new instance is created with the fitted parameters.
132130
"""
133-
init_parms = {}
134-
for k in cls.params_names:
135-
if k in fixed_values:
136-
v = fixed_values.pop(k)
137-
if isinstance(v, Parametrized):
138-
init_parms[k] = v
139-
else:
140-
init_parms[k] = ConstantParameter(v)
141-
# Add keyword arguments useful for object creation
142-
for k, v in fixed_values.items():
143-
if k not in init_parms:
144-
init_parms[k] = v
145-
init = cls(**init_parms)
131+
init_parms = self._process_fit_params(**fixed_values)
132+
init = type(self)(**init_parms)
146133
data = init._apply_constraints(data)
147134

148135
if x0 is None:
149-
x0 = [x.value for x in init.optimisation_params]
136+
x0 = [x() for x in init.optimisation_params]
150137
else:
151138
if len(x0) != len(init.optimisation_params):
152139
raise ValueError(
@@ -167,6 +154,14 @@ def to_minimize(x) -> float:
167154

168155
return Fit(dist, data, score, x0=x0, optimize_result=optimization_result)
169156

157+
def fit_instance(self, *args, **kwargs):
158+
warnings.warn(
159+
"fit_instance is deprecated, use fit instead",
160+
DeprecationWarning,
161+
stacklevel=2,
162+
)
163+
return self.fit(*args, **kwargs)
164+
170165
def _process_fit_params(self, **kwds):
171166
out_dict = self.param_dict.copy()
172167
to_remove = set()
@@ -195,38 +190,6 @@ def _process_fit_params(self, **kwds):
195190
out_dict[name] = value
196191
return out_dict
197192

198-
def fit_instance(
199-
self,
200-
data: Obs,
201-
score=opposite_log_likelihood,
202-
x0: Sequence[float] | None = None,
203-
scipy_args: dict | None = None,
204-
**fixed_values,
205-
) -> Fit[Self]:
206-
"""
207-
Fit the instance to the data.
208-
209-
Parameters
210-
----------
211-
data : Obs
212-
Data to fit the instance to.
213-
score : Callable[["Distribution", Obs], float], optional
214-
Scoring function, by default opposite_log_likelihood.
215-
x0 : Sequence[float], optional
216-
Initial guess for the parameters, by default None.
217-
scipy_args : Optional[Dict], optional
218-
Additional arguments for scipy.optimize.minimize, by default None.
219-
fixed_values : dict
220-
Fixed values for the parameters.
221-
222-
Returns
223-
-------
224-
Distribution
225-
Fitted instance.
226-
"""
227-
param_dict = self._process_fit_params(**fixed_values)
228-
return self.fit(data, score=score, x0=x0, scipy_args=scipy_args, **param_dict)
229-
230193

231194
@dataclass
232195
class Fit(Generic[_T]):

pykelihood/distributions/custom.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def _apply_constraints(self, x):
425425
"""
426426
return x[self._valid_indices(x)]
427427

428-
def fit_instance(self, *args, **kwargs):
428+
def fit(self, *args, **kwargs):
429429
"""
430430
Fit the instance to the data.
431431
@@ -442,7 +442,7 @@ def fit_instance(self, *args, **kwargs):
442442
The fitted instance.
443443
"""
444444
kwargs.update(lower_bound=self.lower_bound, upper_bound=self.upper_bound)
445-
return super().fit_instance(*args, **kwargs)
445+
return super().fit(*args, **kwargs)
446446

447447
def rvs(self, size: int, *args, **kwargs):
448448
"""

pykelihood/parameters.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,13 @@ def param_mapping(self, only_opt=False):
166166
return results
167167

168168
@property
169-
def optimisation_params(self) -> tuple[Parametrized]:
169+
def optimisation_params(self) -> tuple[Parameter, ...]:
170170
"""
171171
Get all parameters used in the optimization.
172172
173173
Returns
174174
-------
175-
Tuple[Parametrized]
176-
The optimization parameters.
175+
The optimization parameters.
177176
"""
178177
unique = []
179178
for q in (p_ for p in self.params for p_ in p.optimisation_params):
@@ -182,13 +181,13 @@ def optimisation_params(self) -> tuple[Parametrized]:
182181
return unique
183182

184183
@property
185-
def optimisation_param_dict(self) -> dict[str, Parametrized]:
184+
def optimisation_param_dict(self) -> dict[str, Parameter]:
186185
"""
187186
Get a dictionary of optimization parameter names and their values.
188187
189188
Returns
190189
-------
191-
Dict[str, Parametrized]
190+
Dict[str, Parameter]
192191
The optimization parameter dictionary.
193192
"""
194193
p_dict = flatten_dict(self._optimisation_param_dict_helper())

pykelihood/profiler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def optimum(self):
8686
tuple
8787
A tuple containing the estimate and the score function value.
8888
"""
89-
estimate = self.distribution.fit_instance(
89+
estimate = self.distribution.fit(
9090
self.data,
9191
score=self.score_function,
9292
x0=self.x0,
@@ -148,11 +148,7 @@ def test_profile_likelihood(self, range_for_param, param):
148148
profile_ll = []
149149
params = []
150150
for x in range_for_param:
151-
pl = opt.fit_instance(
152-
self.data,
153-
score=self.score_function,
154-
**{param: x},
155-
)
151+
pl = opt.fit(self.data, score=self.score_function, **{param: x})
156152
pl_value = -self.score_function(pl, self.data)
157153
pl_value = pl_value if isinstance(pl_value, float) else pl_value[0]
158154
if np.isfinite(pl_value):
@@ -190,9 +186,7 @@ def confidence_interval(self, param: str, precision=1e-5) -> tuple[float, float]
190186
value_threshold = func - chi2.ppf(self.inference_confidence, df=1) / 2
191187

192188
def score(x: float):
193-
new_opt = opt.fit_instance(
194-
self.data, score=self.score_function, **{param: x}
195-
)
189+
new_opt = opt.fit(self.data, score=self.score_function, **{param: x})
196190
return -self.score_function(new_opt, self.data)
197191

198192
def delta_to_threshold(x: float):

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,3 @@ extend-select = [
5757
"I", # Import sorting
5858
"UP", # PyUpgrade
5959
]
60-
61-
[tool.pytest.ini_options]
62-
durations = 10
63-
verbose = true

tests/test_distributions.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,22 @@ class TestGEV:
2929
def test_fit(self, datasets):
3030
for ds in datasets:
3131
c, loc, scale = stats.genextreme.fit(ds)
32-
fit = GEV.fit(ds)
32+
fit = GEV().fit(ds)
3333
assert fit.loc() == approx(loc)
3434
assert fit.scale() == approx(scale)
3535
assert fit.shape() == approx(-c)
3636

3737
def test_fixed_values(self):
3838
data = np.random.standard_normal(1000)
39-
raw = Normal.fit(data)
39+
raw = Normal().fit(data)
4040
assert raw.loc() == approx(0.0)
4141
assert raw.scale() == approx(1.0)
42-
fixed = Normal.fit(data, loc=1.0)
42+
fixed = Normal().fit(data, loc=1.0)
4343
assert fixed.loc() == 1.0
4444

4545

4646
def test_cache():
47+
"""There is no cache anymore, the test is kept as it can still be useful."""
4748
n = Normal(0, 1)
4849
np.testing.assert_array_almost_equal(
4950
n.pdf([-1, 0, 1]), [0.24197072, 0.39894228, 0.24197072]
@@ -87,28 +88,27 @@ def test_named_with_params_partial_assignment():
8788
assert m.scale() == 3
8889

8990

90-
def test_fit_instance(dataset):
91-
std_fit = Normal.fit(dataset)
92-
instance_fit = Normal(loc=kernels.constant()).fit_instance(dataset)
93-
assert std_fit.loc() == approx(instance_fit.loc())
91+
def test_simple_fit(dataset):
92+
std_fit = Normal().fit(dataset)
93+
kernel_fit = Normal(loc=kernels.constant()).fit(dataset)
94+
assert std_fit.loc() == approx(kernel_fit.loc())
9495

9596

96-
def test_fit_instance_fixed_params(dataset):
97-
n = Normal().fit_instance(dataset, loc=5)
97+
def test_fit_fixed_param(dataset):
98+
n = Normal().fit(dataset, loc=5)
9899
assert n.loc() == 5
99100

100101

101-
def test_fit_instance_fixed_params_multi_level(dataset, linear_kernel):
102+
def test_fit_fixed_param_depth_2(dataset, linear_kernel):
102103
n = Normal(loc=linear_kernel)
103-
m = n.fit_instance(dataset, loc_a=5)
104+
m = n.fit(dataset, loc_a=5)
104105
assert m.loc.a() == 5
105106

106107

107-
def test_fit_instance_fixed_params_extra_levels(dataset):
108+
def test_fit_fixed_param_depth_3(dataset):
108109
covariate = np.arange(len(dataset))
109110
n = Normal(loc=kernels.linear(covariate, a=kernels.linear(covariate)))
110-
n.param_mapping()
111-
m = n.fit_instance(dataset, loc_a_a=5)
111+
m = n.fit(dataset, loc_a_a=5)
112112
assert m.loc.a.a() == 5
113113

114114

@@ -141,8 +141,8 @@ def test_truncated_distribution_fit():
141141
data = n.rvs(10000)
142142
trunc_data = data[data >= 0]
143143
truncated = TruncatedDistribution(Normal(), lower_bound=0)
144-
fitted_all_data = truncated.fit_instance(data)
145-
fitted_trunc = truncated.fit_instance(trunc_data)
144+
fitted_all_data = truncated.fit(data)
145+
fitted_trunc = truncated.fit(trunc_data)
146146
for p_trunc, p_all in zip(
147147
fitted_trunc.flattened_params, fitted_all_data.flattened_params
148148
):
@@ -156,27 +156,27 @@ def test_distribution_fit_with_shared_params_in_trends():
156156
"""
157157
x = np.array(np.random.uniform(size=200))
158158
y = np.array(np.random.normal(size=200))
159-
alpha0_init = 0.0
160-
alpha = Parameter(alpha0_init)
161-
n = Normal.fit(y, loc=linear(x=x, b=alpha), scale=linear(x=x, b=alpha))
159+
alpha = Parameter(0.0)
160+
n = Normal().fit(y, loc=linear(x=x, b=alpha), scale=linear(x=x, b=alpha))
162161
alpha1 = n.loc.b
163162
alpha2 = n.scale.b
164163
assert alpha1 == alpha2
165164

166165

167-
def test_fit_instance_fixing_shared_params_in_trends():
166+
def test_fit_fixing_shared_params_in_trends():
168167
"""
169-
when 2 trends in the distribution parameters share a common parameter, e.g. alpha in the below example, making one of the corresponding trend parameter constant should automatically result in the other trend parameter is constant.
168+
When two trends in the distribution parameters share a common parameter,
169+
e.g. alpha in the example below, making one of the corresponding trend parameters
170+
constant should automatically result in the other trend parameter being constant.
170171
"""
171172
x = np.array(np.random.uniform(size=200))
172173
y = np.array(np.random.normal(size=200))
173-
alpha0_init = 0.0
174-
alpha = Parameter(alpha0_init)
175-
n = Normal.fit(y, loc=linear(x=x, b=alpha), scale=linear(x=x, b=alpha))
174+
alpha = Parameter(0.0)
175+
n = Normal().fit(y, loc=linear(x=x, b=alpha), scale=linear(x=x, b=alpha))
176176
fixed_alpha = ConstantParameter(
177177
n.loc.b.value
178-
) # should be equal to fit.scale.a as per problem1
179-
fit_with_fixed_alpha = n.fit_instance(data=y, loc_b=fixed_alpha)
178+
) # should be equal to fit.scale.b as per previous test above
179+
fit_with_fixed_alpha = n.fit(data=y, loc_b=fixed_alpha)
180180
assert isinstance(fit_with_fixed_alpha.scale.b, ConstantParameter)
181181
assert fit_with_fixed_alpha.scale.b.value == fixed_alpha.value
182182

0 commit comments

Comments
 (0)