Skip to content

Commit ce535b3

Browse files
AustinRochfordColCarroll
authored andcommitted
Add logistic distribution (#2588)
1 parent 9d44c39 commit ce535b3

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .continuous import SkewNormal
2727
from .continuous import Triangular
2828
from .continuous import Gumbel
29+
from .continuous import Logistic
2930
from .continuous import Interpolated
3031

3132
from .discrete import Binomial
@@ -141,6 +142,7 @@
141142
'Triangular',
142143
'DiscreteWeibull',
143144
'Gumbel',
145+
'Logistic',
144146
'Interpolated',
145147
'Bound',
146148
]

pymc3/distributions/continuous.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
'Laplace', 'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
2929
'HalfStudentT', 'StudentTpos', 'Lognormal', 'ChiSquared',
3030
'HalfNormal', 'Wald', 'Pareto', 'InverseGamma', 'ExGaussian',
31-
'VonMises', 'SkewNormal', 'Interpolated']
31+
'VonMises', 'SkewNormal', 'Logistic', 'Interpolated']
3232

3333

3434
class PositiveContinuous(Continuous):
@@ -1931,6 +1931,83 @@ def _repr_latex_(self, name=None, dist=None):
19311931
get_variable_name(beta))
19321932

19331933

1934+
class Logistic(Continuous):
1935+
R"""
1936+
Logistic log-likelihood.
1937+
1938+
.. math::
1939+
1940+
f(x \mid \mu, s) =
1941+
\frac{\exp\left(-\frac{x - \mu}{s}\right)}{s \left(1 + \exp\left(-\frac{x - \mu}{s}\right)\right)^2}
1942+
1943+
======== ==========================================
1944+
Support :math:`x \in \mathbb{R}`
1945+
Mean :math:`\mu`
1946+
Variance :math:`\frac{s^2 \pi^2}{3}`
1947+
======== ==========================================
1948+
1949+
.. plot::
1950+
1951+
import matplotlib.pyplot as plt
1952+
import numpy as np
1953+
import scipy.stats as st
1954+
x = np.linspace(-5.0, 5.0, 1000)
1955+
fig, ax = plt.subplots()
1956+
f = lambda mu, s : st.logistic.pdf(x, loc=mu, scale=s)
1957+
plot_pdf = lambda a, b : ax.plot(x, f(a,b), label=r'$\mu$={0}, $s$={1}'.format(a,b))
1958+
plot_pdf(0.0, 0.4)
1959+
plot_pdf(0.0, 1.0)
1960+
plot_pdf(0.0, 2.0)
1961+
plot_pdf(-2.0, 0.4)
1962+
plt.legend(loc='upper right', frameon=False)
1963+
ax.set(xlim=[-5,5], ylim=[0,1.2], xlabel='x', ylabel='f(x)')
1964+
plt.show()
1965+
1966+
Parameters
1967+
----------
1968+
mu : float
1969+
Mean.
1970+
s : float
1971+
Scale (s > 0).
1972+
"""
1973+
def __init__(self, mu=0., s=1., *args, **kwargs):
1974+
super(Logistic, self).__init__(*args, **kwargs)
1975+
1976+
self.mu = tt.as_tensor_variable(mu)
1977+
self.s = tt.as_tensor_variable(s)
1978+
1979+
self.mean = self.mode = mu
1980+
self.variance = s**2 * np.pi**2 / 3.
1981+
1982+
def logp(self, value):
1983+
mu = self.mu
1984+
s = self.s
1985+
1986+
return bound(
1987+
-(value - mu) / s - tt.log(s) - 2 * tt.log1p(tt.exp(-(value - mu) / s)),
1988+
s > 0
1989+
)
1990+
1991+
def random(self, point=None, size=None, repeat=None):
1992+
mu, s = draw_values([self.mu, self.s], point=point)
1993+
1994+
return generate_samples(
1995+
stats.logistic.rvs,
1996+
loc=mu, scale=s,
1997+
dist_shape=self.shape,
1998+
size=size
1999+
)
2000+
2001+
def _repr_latex_(self, name=None, dist=None):
2002+
if dist is None:
2003+
dist = self
2004+
mu = dist.mu
2005+
s = dist.s
2006+
return r'${} \sim \text{{Logistic}}(\mathit{{mu}}={}, \mathit{{s}}={})$'.format(name,
2007+
get_variable_name(mu),
2008+
get_variable_name(s))
2009+
2010+
19342011
class Interpolated(Continuous):
19352012
R"""
19362013
Univariate probability distribution defined as a linear interpolation

pymc3/tests/test_distributions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
1616
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
17-
Gumbel, Interpolated, ZeroInflatedBinomial, HalfFlat, AR1)
17+
Gumbel, Logistic, Interpolated, ZeroInflatedBinomial, HalfFlat, AR1)
1818
from ..distributions import continuous
1919
from pymc3.theanof import floatX
2020
from numpy import array, inf, log, exp
@@ -895,6 +895,11 @@ def gumbel(value, mu, beta):
895895
return floatX(sp.gumbel_r.logpdf(value, loc=mu, scale=beta))
896896
self.pymc3_matches_scipy(Gumbel, R, {'mu': R, 'beta': Rplusbig}, gumbel)
897897

898+
def test_logistic(self):
899+
self.pymc3_matches_scipy(Logistic, R, {'mu': R, 's': Rplus},
900+
lambda value, mu, s: sp.logistic.logpdf(value, mu, s),
901+
decimal=select_by_precision(float64=6, float32=1))
902+
898903
def test_multidimensional_beta_construction(self):
899904
with Model():
900905
Beta('beta', alpha=1., beta=1., shape=(10, 20))

0 commit comments

Comments
 (0)