Skip to content

Commit e6e7c19

Browse files
Remove AvoidAbstractMixin hack (#52)
1 parent 505bc95 commit e6e7c19

File tree

1 file changed

+33
-67
lines changed

1 file changed

+33
-67
lines changed

pykelihood/distributions/base.py

Lines changed: 33 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
from abc import abstractmethod
3+
from abc import ABC, abstractmethod
44
from collections.abc import Sequence
55
from dataclasses import dataclass
6-
from functools import partial
76
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
87

98
import numpy as np
9+
from numpy.typing import ArrayLike
1010
from scipy import stats
1111
from scipy.optimize import OptimizeResult, minimize
1212

@@ -21,7 +21,7 @@
2121
SomeDistribution = TypeVar("SomeDistribution", bound="Distribution")
2222

2323

24-
class Distribution(Parametrized):
24+
class Distribution(Parametrized, ABC):
2525
"""
2626
Base class for all distributions.
2727
@@ -273,76 +273,42 @@ def __getattr__(self, item: str):
273273
return getattr(self.fitted, item)
274274

275275

276-
class AvoidAbstractMixin:
276+
class ScipyDistribution(Distribution, ABC):
277277
"""
278-
Mixin to avoid abstract methods.
279-
280-
Methods
281-
-------
282-
__getattribute__(item)
283-
Get the attribute, avoiding abstract methods.
278+
Base class for distributions based on SciPy.
284279
"""
285280

286-
def __getattribute__(self, item):
287-
x = object.__getattribute__(self, item)
288-
if (
289-
hasattr(x, "__isabstractmethod__")
290-
and x.__isabstractmethod__
291-
and hasattr(self, "__getattr__")
292-
):
293-
x = self.__getattr__(item)
294-
return x
281+
_base_module: stats.rv_continuous
295282

283+
@abstractmethod
284+
def _to_scipy_args(self, **params) -> dict[str, ArrayLike]:
285+
raise NotImplementedError
296286

297-
class ScipyDistribution(Distribution, AvoidAbstractMixin):
298-
"""
299-
Base class for distributions using scipy.
287+
def rvs(self, size=None, random_state=None, **kwargs):
288+
return self._base_module.rvs(
289+
**self._to_scipy_args(**kwargs), size=size, random_state=random_state
290+
)
300291

301-
Methods
302-
-------
303-
rvs(size=None, random_state=None, **kwargs)
304-
Generate random variates.
305-
_wrapper(f, x, **extra_args)
306-
Wrapper for scipy functions.
307-
__getattr__(item)
308-
Get the attribute, wrapping scipy functions.
309-
"""
292+
def pdf(self, x, **kwargs):
293+
return self._base_module.pdf(x, **self._to_scipy_args(**kwargs))
310294

311-
_base_module: stats.rv_continuous
295+
def cdf(self, x, **kwargs):
296+
return self._base_module.cdf(x, **self._to_scipy_args(**kwargs))
312297

313-
def rvs(self, size=None, random_state=None, **kwargs):
314-
base_rvs = getattr(self._base_module, "rvs")
315-
params = {p: kwargs.pop(p) for p in self.params_names if p in kwargs}
316-
return base_rvs(
317-
**self._to_scipy_args(**params),
318-
size=size,
319-
random_state=random_state,
320-
**kwargs,
321-
)
298+
def isf(self, q, **kwargs):
299+
return self._base_module.isf(q, **self._to_scipy_args(**kwargs))
300+
301+
def ppf(self, q, **kwargs):
302+
return self._base_module.ppf(q, **self._to_scipy_args(**kwargs))
303+
304+
def sf(self, x, **kwargs):
305+
return self._base_module.sf(x, **self._to_scipy_args(**kwargs))
306+
307+
def logpdf(self, x, **kwargs):
308+
return self._base_module.logpdf(x, **self._to_scipy_args(**kwargs))
309+
310+
def logcdf(self, x, **kwargs):
311+
return self._base_module.logcdf(x, **self._to_scipy_args(**kwargs))
322312

323-
def _wrapper(self, f, x, **extra_args):
324-
params = {}
325-
other_args = {}
326-
for key, value in extra_args.items():
327-
if key in self.params_names:
328-
params[key] = value
329-
else:
330-
other_args[key] = value
331-
return f(x, **self._to_scipy_args(**params), **other_args)
332-
333-
def __getattr__(self, item):
334-
if item not in (
335-
"pdf",
336-
"logpdf",
337-
"cdf",
338-
"logcdf",
339-
"ppf",
340-
"isf",
341-
"sf",
342-
"logsf",
343-
):
344-
return super().__getattr__(item)
345-
f = getattr(self._base_module, item)
346-
g = partial(self._wrapper, f)
347-
self.__dict__[item] = g
348-
return g
313+
def logsf(self, x, **kwargs):
314+
return self._base_module.logsf(x, **self._to_scipy_args(**kwargs))

0 commit comments

Comments
 (0)