Skip to content

Commit 505bc95

Browse files
Cleanup dynamic distribution class creation (#47)
* Cleanup dynamic distribution class creation * Rename Inner to Wrapper
1 parent 8386658 commit 505bc95

File tree

3 files changed

+139
-141
lines changed

3 files changed

+139
-141
lines changed
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from pykelihood.distributions.scipy import *
2-
from pykelihood.distributions.custom import *
1+
from pykelihood.distributions import custom, scipy
32
from pykelihood.distributions.base import *
4-
from pykelihood.distributions import scipy
5-
from pykelihood.distributions import custom
3+
from pykelihood.distributions.custom import *
4+
from pykelihood.distributions.scipy import *
65

76
__all__ = [*scipy.__all__, *custom.__all__, "Distribution", "ScipyDistribution"]

pykelihood/distributions/custom.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import numpy as np
44
from scipy import stats as _stats
55

6-
from pykelihood.distributions.base import Distribution
7-
from pykelihood.distributions.base import ScipyDistribution
6+
from pykelihood.distributions.base import Distribution, ScipyDistribution
87
from pykelihood.generic_types import Obs
98
from pykelihood.utils import ifnone
109

@@ -463,7 +462,7 @@ def rvs(self, size: int, *args, **kwargs):
463462
np.ndarray
464463
Random variates.
465464
"""
466-
u = Uniform(
465+
u = _stats.uniform(
467466
self.distribution.cdf(self.lower_bound),
468467
self.distribution.cdf(self.upper_bound),
469468
)

pykelihood/distributions/scipy.py

Lines changed: 134 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def _name_from_scipy_dist(scipy_dist: stats.rv_continuous) -> str:
1313
return "".join(map(str.capitalize, scipy_dist_name.split("_")))
1414

1515

16-
def _wrap_scipy_distribution(
16+
def wrap_scipy_distribution(
1717
scipy_dist: stats.rv_continuous,
1818
) -> type[ScipyDistribution]:
1919
"""Wrap a scipy distribution class to create a ScipyDistribution subclass."""
2020
scipy_dist_name = type(scipy_dist).__name__.removesuffix("_gen")
2121
clean_dist_name = _name_from_scipy_dist(scipy_dist)
22-
params_names = ("loc", "scale") + tuple(
22+
dist_params_names = ("loc", "scale") + tuple(
2323
scipy_dist.shapes.split(", ") if scipy_dist.shapes else ()
2424
)
2525

@@ -40,151 +40,151 @@ def format_param_docstring(param: str) -> str:
4040
Shape parameter. See the SciPy documentation for the {scipy_dist_name} distribution for details.\
4141
"""
4242

43-
for param in params_names[2:]:
43+
for param in dist_params_names[2:]:
4444
docstring += format_param_docstring(param)
4545

46-
def __init__(self, loc=0.0, scale=1.0, **kwargs):
47-
shape_args = params_names[2:]
48-
for arg in shape_args:
49-
if arg not in kwargs:
50-
raise ValueError(f"Missing shape parameter: {arg}")
51-
args = [kwargs[a] for a in shape_args]
52-
ScipyDistribution.__init__(self, loc, scale, *args)
46+
class Wrapper(ScipyDistribution):
47+
_base_module = scipy_dist
48+
params_names = dist_params_names
49+
__doc__ = docstring
5350

54-
def _to_scipy_args(self, **kwargs):
55-
return {k: kwargs.get(k, getattr(self, k)()) for k in self.params_names}
51+
def __init__(self, loc=0.0, scale=1.0, **kwargs):
52+
assert self.params_names[:2] == ("loc", "scale")
53+
shape_args = self.params_names[2:]
54+
for arg in shape_args:
55+
if arg not in kwargs:
56+
raise ValueError(
57+
f"Missing shape parameter `{arg}` when initializing {type(self).__name__} distribution."
58+
)
59+
args = [kwargs[a] for a in shape_args]
60+
super().__init__(loc, scale, *args)
5661

57-
return type(
58-
clean_dist_name,
59-
(ScipyDistribution,),
60-
{
61-
"_base_module": scipy_dist,
62-
"params_names": params_names,
63-
"__init__": __init__,
64-
"_to_scipy_args": _to_scipy_args,
65-
"__doc__": docstring,
66-
},
67-
)
62+
def _to_scipy_args(self, **kwargs):
63+
return {k: kwargs.get(k, getattr(self, k)()) for k in self.params_names}
64+
65+
Wrapper.__name__ = clean_dist_name
66+
Wrapper.__qualname__ = f"{Wrapper.__module__}.{Wrapper.__name__}"
67+
return Wrapper
6868

6969

70-
Alpha = _wrap_scipy_distribution(stats.alpha)
71-
Anglit = _wrap_scipy_distribution(stats.anglit)
72-
Arcsine = _wrap_scipy_distribution(stats.arcsine)
73-
Argus = _wrap_scipy_distribution(stats.argus)
74-
# Beta = _wrap_scipy_distribution(stats.beta)
75-
Betaprime = _wrap_scipy_distribution(stats.betaprime)
76-
Bradford = _wrap_scipy_distribution(stats.bradford)
77-
Burr = _wrap_scipy_distribution(stats.burr)
78-
Burr12 = _wrap_scipy_distribution(stats.burr12)
79-
Cauchy = _wrap_scipy_distribution(stats.cauchy)
80-
Chi = _wrap_scipy_distribution(stats.chi)
81-
Chi2 = _wrap_scipy_distribution(stats.chi2)
82-
Cosine = _wrap_scipy_distribution(stats.cosine)
83-
Crystalball = _wrap_scipy_distribution(stats.crystalball)
84-
Dgamma = _wrap_scipy_distribution(stats.dgamma)
85-
Dweibull = _wrap_scipy_distribution(stats.dweibull)
86-
Erlang = _wrap_scipy_distribution(stats.erlang)
87-
Expon = _wrap_scipy_distribution(stats.expon)
88-
Exponnorm = _wrap_scipy_distribution(stats.exponnorm)
89-
Exponpow = _wrap_scipy_distribution(stats.exponpow)
90-
Exponweib = _wrap_scipy_distribution(stats.exponweib)
91-
F = _wrap_scipy_distribution(stats.f)
92-
Fatiguelife = _wrap_scipy_distribution(stats.fatiguelife)
93-
Fisk = _wrap_scipy_distribution(stats.fisk)
94-
Foldcauchy = _wrap_scipy_distribution(stats.foldcauchy)
95-
Foldnorm = _wrap_scipy_distribution(stats.foldnorm)
96-
# Gamma = _wrap_scipy_distribution(stats.gamma)
97-
Gausshyper = _wrap_scipy_distribution(stats.gausshyper)
98-
Genexpon = _wrap_scipy_distribution(stats.genexpon)
99-
Genextreme = _wrap_scipy_distribution(stats.genextreme)
100-
Gengamma = _wrap_scipy_distribution(stats.gengamma)
101-
Genhalflogistic = _wrap_scipy_distribution(stats.genhalflogistic)
102-
Genhyperbolic = _wrap_scipy_distribution(stats.genhyperbolic)
103-
Geninvgauss = _wrap_scipy_distribution(stats.geninvgauss)
104-
Genlogistic = _wrap_scipy_distribution(stats.genlogistic)
105-
Gennorm = _wrap_scipy_distribution(stats.gennorm)
106-
Genpareto = _wrap_scipy_distribution(stats.genpareto)
107-
Gibrat = _wrap_scipy_distribution(stats.gibrat)
108-
Gompertz = _wrap_scipy_distribution(stats.gompertz)
109-
GumbelL = _wrap_scipy_distribution(stats.gumbel_l)
110-
GumbelR = _wrap_scipy_distribution(stats.gumbel_r)
111-
Halfcauchy = _wrap_scipy_distribution(stats.halfcauchy)
112-
Halfgennorm = _wrap_scipy_distribution(stats.halfgennorm)
113-
Halflogistic = _wrap_scipy_distribution(stats.halflogistic)
114-
Halfnorm = _wrap_scipy_distribution(stats.halfnorm)
115-
Hypsecant = _wrap_scipy_distribution(stats.hypsecant)
116-
Invgamma = _wrap_scipy_distribution(stats.invgamma)
117-
Invgauss = _wrap_scipy_distribution(stats.invgauss)
118-
Invweibull = _wrap_scipy_distribution(stats.invweibull)
119-
JfSkewT = _wrap_scipy_distribution(stats.jf_skew_t)
120-
Johnsonsb = _wrap_scipy_distribution(stats.johnsonsb)
121-
Johnsonsu = _wrap_scipy_distribution(stats.johnsonsu)
122-
Kappa3 = _wrap_scipy_distribution(stats.kappa3)
123-
Kappa4 = _wrap_scipy_distribution(stats.kappa4)
124-
Ksone = _wrap_scipy_distribution(stats.ksone)
125-
Kstwo = _wrap_scipy_distribution(stats.kstwo)
126-
Kstwobign = _wrap_scipy_distribution(stats.kstwobign)
127-
Laplace = _wrap_scipy_distribution(stats.laplace)
128-
LaplaceAsymmetric = _wrap_scipy_distribution(stats.laplace_asymmetric)
129-
Levy = _wrap_scipy_distribution(stats.levy)
130-
LevyL = _wrap_scipy_distribution(stats.levy_l)
131-
LevyStable = _wrap_scipy_distribution(stats.levy_stable)
132-
Loggamma = _wrap_scipy_distribution(stats.loggamma)
133-
Logistic = _wrap_scipy_distribution(stats.logistic)
134-
Loglaplace = _wrap_scipy_distribution(stats.loglaplace)
135-
Lognorm = _wrap_scipy_distribution(stats.lognorm)
136-
Lomax = _wrap_scipy_distribution(stats.lomax)
137-
Maxwell = _wrap_scipy_distribution(stats.maxwell)
138-
Mielke = _wrap_scipy_distribution(stats.mielke)
139-
Moyal = _wrap_scipy_distribution(stats.moyal)
140-
Nakagami = _wrap_scipy_distribution(stats.nakagami)
141-
Ncf = _wrap_scipy_distribution(stats.ncf)
142-
Nct = _wrap_scipy_distribution(stats.nct)
143-
Ncx2 = _wrap_scipy_distribution(stats.ncx2)
144-
Norm = _wrap_scipy_distribution(stats.norm)
70+
Alpha = wrap_scipy_distribution(stats.alpha)
71+
Anglit = wrap_scipy_distribution(stats.anglit)
72+
Arcsine = wrap_scipy_distribution(stats.arcsine)
73+
Argus = wrap_scipy_distribution(stats.argus)
74+
# Beta = wrap_scipy_distribution(stats.beta)
75+
Betaprime = wrap_scipy_distribution(stats.betaprime)
76+
Bradford = wrap_scipy_distribution(stats.bradford)
77+
Burr = wrap_scipy_distribution(stats.burr)
78+
Burr12 = wrap_scipy_distribution(stats.burr12)
79+
Cauchy = wrap_scipy_distribution(stats.cauchy)
80+
Chi = wrap_scipy_distribution(stats.chi)
81+
Chi2 = wrap_scipy_distribution(stats.chi2)
82+
Cosine = wrap_scipy_distribution(stats.cosine)
83+
Crystalball = wrap_scipy_distribution(stats.crystalball)
84+
Dgamma = wrap_scipy_distribution(stats.dgamma)
85+
Dweibull = wrap_scipy_distribution(stats.dweibull)
86+
Erlang = wrap_scipy_distribution(stats.erlang)
87+
Expon = wrap_scipy_distribution(stats.expon)
88+
Exponnorm = wrap_scipy_distribution(stats.exponnorm)
89+
Exponpow = wrap_scipy_distribution(stats.exponpow)
90+
Exponweib = wrap_scipy_distribution(stats.exponweib)
91+
F = wrap_scipy_distribution(stats.f)
92+
Fatiguelife = wrap_scipy_distribution(stats.fatiguelife)
93+
Fisk = wrap_scipy_distribution(stats.fisk)
94+
Foldcauchy = wrap_scipy_distribution(stats.foldcauchy)
95+
Foldnorm = wrap_scipy_distribution(stats.foldnorm)
96+
# Gamma = wrap_scipy_distribution(stats.gamma)
97+
Gausshyper = wrap_scipy_distribution(stats.gausshyper)
98+
Genexpon = wrap_scipy_distribution(stats.genexpon)
99+
Genextreme = wrap_scipy_distribution(stats.genextreme)
100+
Gengamma = wrap_scipy_distribution(stats.gengamma)
101+
Genhalflogistic = wrap_scipy_distribution(stats.genhalflogistic)
102+
Genhyperbolic = wrap_scipy_distribution(stats.genhyperbolic)
103+
Geninvgauss = wrap_scipy_distribution(stats.geninvgauss)
104+
Genlogistic = wrap_scipy_distribution(stats.genlogistic)
105+
Gennorm = wrap_scipy_distribution(stats.gennorm)
106+
Genpareto = wrap_scipy_distribution(stats.genpareto)
107+
Gibrat = wrap_scipy_distribution(stats.gibrat)
108+
Gompertz = wrap_scipy_distribution(stats.gompertz)
109+
GumbelL = wrap_scipy_distribution(stats.gumbel_l)
110+
GumbelR = wrap_scipy_distribution(stats.gumbel_r)
111+
Halfcauchy = wrap_scipy_distribution(stats.halfcauchy)
112+
Halfgennorm = wrap_scipy_distribution(stats.halfgennorm)
113+
Halflogistic = wrap_scipy_distribution(stats.halflogistic)
114+
Halfnorm = wrap_scipy_distribution(stats.halfnorm)
115+
Hypsecant = wrap_scipy_distribution(stats.hypsecant)
116+
Invgamma = wrap_scipy_distribution(stats.invgamma)
117+
Invgauss = wrap_scipy_distribution(stats.invgauss)
118+
Invweibull = wrap_scipy_distribution(stats.invweibull)
119+
JfSkewT = wrap_scipy_distribution(stats.jf_skew_t)
120+
Johnsonsb = wrap_scipy_distribution(stats.johnsonsb)
121+
Johnsonsu = wrap_scipy_distribution(stats.johnsonsu)
122+
Kappa3 = wrap_scipy_distribution(stats.kappa3)
123+
Kappa4 = wrap_scipy_distribution(stats.kappa4)
124+
Ksone = wrap_scipy_distribution(stats.ksone)
125+
Kstwo = wrap_scipy_distribution(stats.kstwo)
126+
Kstwobign = wrap_scipy_distribution(stats.kstwobign)
127+
Laplace = wrap_scipy_distribution(stats.laplace)
128+
LaplaceAsymmetric = wrap_scipy_distribution(stats.laplace_asymmetric)
129+
Levy = wrap_scipy_distribution(stats.levy)
130+
LevyL = wrap_scipy_distribution(stats.levy_l)
131+
LevyStable = wrap_scipy_distribution(stats.levy_stable)
132+
Loggamma = wrap_scipy_distribution(stats.loggamma)
133+
Logistic = wrap_scipy_distribution(stats.logistic)
134+
Loglaplace = wrap_scipy_distribution(stats.loglaplace)
135+
Lognorm = wrap_scipy_distribution(stats.lognorm)
136+
Lomax = wrap_scipy_distribution(stats.lomax)
137+
Maxwell = wrap_scipy_distribution(stats.maxwell)
138+
Mielke = wrap_scipy_distribution(stats.mielke)
139+
Moyal = wrap_scipy_distribution(stats.moyal)
140+
Nakagami = wrap_scipy_distribution(stats.nakagami)
141+
Ncf = wrap_scipy_distribution(stats.ncf)
142+
Nct = wrap_scipy_distribution(stats.nct)
143+
Ncx2 = wrap_scipy_distribution(stats.ncx2)
144+
Norm = wrap_scipy_distribution(stats.norm)
145145
Normal = Norm # alias for backward compatibility
146-
Norminvgauss = _wrap_scipy_distribution(stats.norminvgauss)
147-
# Pareto = _wrap_scipy_distribution(stats.pareto)
148-
Pearson3 = _wrap_scipy_distribution(stats.pearson3)
149-
Powerlaw = _wrap_scipy_distribution(stats.powerlaw)
150-
Powerlognorm = _wrap_scipy_distribution(stats.powerlognorm)
151-
Powernorm = _wrap_scipy_distribution(stats.powernorm)
152-
Rayleigh = _wrap_scipy_distribution(stats.rayleigh)
153-
Rdist = _wrap_scipy_distribution(stats.rdist)
154-
Recipinvgauss = _wrap_scipy_distribution(stats.recipinvgauss)
155-
Loguniform = _wrap_scipy_distribution(stats.loguniform)
156-
Reciprocal = _wrap_scipy_distribution(stats.reciprocal)
157-
RelBreitwigner = _wrap_scipy_distribution(stats.rel_breitwigner)
158-
Rice = _wrap_scipy_distribution(stats.rice)
159-
Semicircular = _wrap_scipy_distribution(stats.semicircular)
160-
Skewcauchy = _wrap_scipy_distribution(stats.skewcauchy)
161-
Skewnorm = _wrap_scipy_distribution(stats.skewnorm)
162-
StudentizedRange = _wrap_scipy_distribution(stats.studentized_range)
163-
T = _wrap_scipy_distribution(stats.t)
164-
Trapezoid = _wrap_scipy_distribution(stats.trapezoid)
146+
Norminvgauss = wrap_scipy_distribution(stats.norminvgauss)
147+
# Pareto = wrap_scipy_distribution(stats.pareto)
148+
Pearson3 = wrap_scipy_distribution(stats.pearson3)
149+
Powerlaw = wrap_scipy_distribution(stats.powerlaw)
150+
Powerlognorm = wrap_scipy_distribution(stats.powerlognorm)
151+
Powernorm = wrap_scipy_distribution(stats.powernorm)
152+
Rayleigh = wrap_scipy_distribution(stats.rayleigh)
153+
Rdist = wrap_scipy_distribution(stats.rdist)
154+
Recipinvgauss = wrap_scipy_distribution(stats.recipinvgauss)
155+
Loguniform = wrap_scipy_distribution(stats.loguniform)
156+
Reciprocal = wrap_scipy_distribution(stats.reciprocal)
157+
RelBreitwigner = wrap_scipy_distribution(stats.rel_breitwigner)
158+
Rice = wrap_scipy_distribution(stats.rice)
159+
Semicircular = wrap_scipy_distribution(stats.semicircular)
160+
Skewcauchy = wrap_scipy_distribution(stats.skewcauchy)
161+
Skewnorm = wrap_scipy_distribution(stats.skewnorm)
162+
StudentizedRange = wrap_scipy_distribution(stats.studentized_range)
163+
T = wrap_scipy_distribution(stats.t)
164+
Trapezoid = wrap_scipy_distribution(stats.trapezoid)
165165
Trapz = Trapezoid
166-
Triang = _wrap_scipy_distribution(stats.triang)
167-
Truncexpon = _wrap_scipy_distribution(stats.truncexpon)
168-
Truncnorm = _wrap_scipy_distribution(stats.truncnorm)
169-
Truncpareto = _wrap_scipy_distribution(stats.truncpareto)
170-
TruncweibullMin = _wrap_scipy_distribution(stats.truncweibull_min)
171-
Tukeylambda = _wrap_scipy_distribution(stats.tukeylambda)
172-
Uniform = _wrap_scipy_distribution(stats.uniform)
173-
Vonmises = _wrap_scipy_distribution(stats.vonmises)
174-
VonmisesLine = _wrap_scipy_distribution(stats.vonmises_line)
175-
Wald = _wrap_scipy_distribution(stats.wald)
176-
WeibullMax = _wrap_scipy_distribution(stats.weibull_max)
177-
WeibullMin = _wrap_scipy_distribution(stats.weibull_min)
178-
Wrapcauchy = _wrap_scipy_distribution(stats.wrapcauchy)
166+
Triang = wrap_scipy_distribution(stats.triang)
167+
Truncexpon = wrap_scipy_distribution(stats.truncexpon)
168+
Truncnorm = wrap_scipy_distribution(stats.truncnorm)
169+
Truncpareto = wrap_scipy_distribution(stats.truncpareto)
170+
TruncweibullMin = wrap_scipy_distribution(stats.truncweibull_min)
171+
Tukeylambda = wrap_scipy_distribution(stats.tukeylambda)
172+
Uniform = wrap_scipy_distribution(stats.uniform)
173+
Vonmises = wrap_scipy_distribution(stats.vonmises)
174+
VonmisesLine = wrap_scipy_distribution(stats.vonmises_line)
175+
Wald = wrap_scipy_distribution(stats.wald)
176+
WeibullMax = wrap_scipy_distribution(stats.weibull_max)
177+
WeibullMin = wrap_scipy_distribution(stats.weibull_min)
178+
Wrapcauchy = wrap_scipy_distribution(stats.wrapcauchy)
179179

180180
if Version(scipy.__version__) >= Version("1.15.0"):
181-
DparetoLognorm = _wrap_scipy_distribution(stats.dpareto_lognorm)
182-
Landau = _wrap_scipy_distribution(stats.landau)
183-
Irwinhall = _wrap_scipy_distribution(stats.irwinhall)
181+
DparetoLognorm = wrap_scipy_distribution(stats.dpareto_lognorm)
182+
Landau = wrap_scipy_distribution(stats.landau)
183+
Irwinhall = wrap_scipy_distribution(stats.irwinhall)
184184

185185
__all__ = [
186186
"_name_from_scipy_dist",
187-
"_wrap_scipy_distribution",
187+
"wrap_scipy_distribution",
188188
"Alpha",
189189
"Anglit",
190190
"Arcsine",

0 commit comments

Comments
 (0)