Commit 050e65d

MAINT: stats: Add custom reprs for transformed distributions (scipy#22037)
* MAINT: sort dict of parameters instead of sorting later
* MAINT: improve reprs of transformed distributions to make them executable
* MAINT: raise if there is no __repr__ override in a transformed distribution subclass
* MAINT: remove __str__ from monotonic transform
* TST: Update __repr__ tests for continuous dists
* TST: Add tests that the reprs evaluate to the correct dist
* BUG: Set array priority to 1 to get reflected operators working
* MAINT: summarize arrays to prevent long output
1 parent ce4ae0e commit 050e65d
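To make the effect concrete: after this change a distribution's repr spells out the parameter values and, with the right names in scope, can be evaluated back into an equivalent distribution. A minimal sketch, assuming a SciPy build that includes this commit (the new random-variable infrastructure, 1.15+); the output shown follows the NumPy 2.x scalar formatting asserted in the updated tests:

    import numpy as np
    from scipy import stats
    from scipy.stats import Normal  # imported so that eval() below can resolve the name

    X = stats.Normal(mu=0, sigma=1)
    print(repr(X))   # e.g. Normal(mu=np.float64(0.0), sigma=np.float64(1.0)) on NumPy 2.x

    # The repr is intended to round-trip: evaluating it reproduces the distribution,
    # which is what the new TestReprs.test_executable checks with shared rng seeds.
    Y = eval(repr(X))
    print(np.all(X.sample(shape=5, rng=1234) == Y.sample(shape=5, rng=1234)))  # True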

2 files changed (+190 lines, -50 lines)


scipy/stats/_distribution_infrastructure.py

Lines changed: 76 additions & 39 deletions
@@ -27,7 +27,8 @@
 def _isnull(x):
     return type(x) is object or x is None

-__all__ = ['ContinuousDistribution']
+__all__ = ['make_distribution', 'Mixture', 'order_statistic',
+           'truncate', 'abs', 'exp', 'log']

 # Could add other policies for broadcasting and edge/out-of-bounds case handling
 # For instance, when edge case handling is known not to be needed, it's much
@@ -1482,6 +1483,7 @@ class ContinuousDistribution(_ProbabilityDistribution):
     text.

     """
+    __array_priority__ = 1
     _parameterizations = []  # type: ignore[var-annotated]

     ### Initialization
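The `__array_priority__ = 1` attribute is what the commit message's "reflected operators" item refers to: NumPy's binary operators defer (return NotImplemented) to an operand that advertises a higher priority than ndarray's default of 0.0 and defines the reflected method, so `array * dist` reaches `dist.__rmul__`. A standalone toy sketch of that NumPy mechanism, not scipy code:

    import numpy as np

    class Defers:
        # Higher than ndarray's default priority (0.0): ndarray.__mul__ should then
        # return NotImplemented and Python falls back to Defers.__rmul__.
        __array_priority__ = 1

        def __rmul__(self, other):
            return ("__rmul__ received", other)

    print(np.ones(2, dtype=np.float32) * Defers())
    # expected: ('__rmul__ received', array([1., 1.], dtype=float32))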
@@ -1501,7 +1503,8 @@ def __init__(self, *, tol=_null, validation_policy=None, cache_policy=None,
         # IDEs can suggest parameter names. If there are multiple parameterizations,
         # we'll need the default values of parameters to be None; this will
         # filter out the parameters that were not actually specified by the user.
-        parameters = {key: val for key, val in parameters.items() if val is not None}
+        parameters = {key: val for key, val in
+                      sorted(parameters.items()) if val is not None}
         self._update_parameters(**parameters)

     def _update_parameters(self, *, validation_policy=None, **params):
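Sorting here, at construction time, is what lets `_get_parameter_str` below (and `__repr__`) simply join the stored keys without re-sorting. A tiny illustration of the comprehension's behaviour:

    # None marks a parameter the user did not pass; it is filtered out, and the
    # remaining items come back in alphabetical order.
    parameters = {"b": 2.0, "a": 1.0, "loc": None}
    parameters = {key: val for key, val in sorted(parameters.items()) if val is not None}
    print(parameters)   # {'a': 1.0, 'b': 2.0}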
@@ -1701,9 +1704,7 @@ def _process_parameters(self, **params):

     def _get_parameter_str(self, parameters):
         # Get a string representation of the parameters like "{a, b, c}".
-        parameter_names_list = list(parameters.keys())
-        parameter_names_list.sort()
-        return f"{{{', '.join(parameter_names_list)}}}"
+        return f"{{{', '.join(parameters.keys())}}}"

     def _copy_parameterization(self):
         self._parameterizations = self._parameterizations.copy()
@@ -1786,25 +1787,17 @@ def __repr__(self):
         r""" Returns a string representation of the distribution.

         Includes the name of the distribution family, the names of the
-        parameters, and the broadcasted shape and result dtype of the
-        parameters.
+        parameters and the `repr` of each of their values.
+

         """
         class_name = self.__class__.__name__
         parameters = list(self._original_parameters.items())
         info = []
-        if parameters:
-            parameters.sort()
-            if self._size <= 3:
-                str_parameters = [f"{symbol}={value}" for symbol, value in parameters]
-                str_parameters = f"{', '.join(str_parameters)}"
-            else:
-                str_parameters = f"{', '.join([symbol for symbol, _ in parameters])}"
-            info.append(str_parameters)
-        if self._shape:
-            info.append(f"shape={self._shape}")
-        if self._dtype != np.float64:
-            info.append(f"dtype={self._dtype}")
+        with np.printoptions(threshold=10):
+            str_parameters = [f"{symbol}={repr(value)}" for symbol, value in parameters]
+            str_parameters = f"{', '.join(str_parameters)}"
+            info.append(str_parameters)
         return f"{class_name}({', '.join(info)})"

     def __add__(self, loc):
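With this rewrite the repr lists `name=value` pairs using each value's own repr, evaluated under `np.printoptions(threshold=10)` so that long parameter arrays are summarized instead of printed in full. A sketch of the intended behaviour; the first expected string is taken from the updated test below, and `Uniform` is imported from the private module exactly as the tests do:

    import numpy as np
    from scipy.stats._new_distributions import Uniform  # private import, as in the tests

    X = Uniform(a=np.zeros(4), b=1)
    print(repr(X))   # Uniform(a=array([0., 0., 0., 0.]), b=1)

    # Larger arrays get summarized by the threshold, keeping the repr short;
    # the new TestReprs.test_not_too_long asserts bounds of this kind (< 250 chars).
    Y = Uniform(a=np.zeros(1000), b=np.ones(1000))
    print(len(repr(Y)) < 250)   # True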
@@ -1825,10 +1818,13 @@ def __pow__(self, other):
                        "implemented when the argument is a positive integer.")
             raise NotImplementedError(message)

-        X = abs(self) if (other % 2 == 0) else self
+        # Fill in repr_pattern with the repr of self before taking abs.
+        # Avoids having unnecessary abs in the repr.
+        with np.printoptions(threshold=10):
+            repr_pattern = f"({repr(self)})**{repr(other)}"
+        X = abs(self) if other % 2 == 0 else self

-        # This notation for g_name is nonstandard
-        funcs = dict(g=lambda u: u**other, g_name=f'pow_{other}',
+        funcs = dict(g=lambda u: u**other, repr_pattern=repr_pattern,
                      h=lambda u: np.sign(u) * np.abs(u)**(1 / other),
                      dh=lambda u: 1/other * np.abs(u)**(1/other - 1))

@@ -1846,8 +1842,10 @@ def __rmul__(self, other):

     def __rtruediv__(self, other):
         a, b = self.support()
-        funcs = dict(g=lambda u: 1 / u, g_name='inv',
-                     h=lambda u: 1 / u, dh=lambda u: 1 / u ** 2)
+        with np.printoptions(threshold=10):
+            funcs = dict(g=lambda u: 1 / u,
+                         repr_pattern=f"{repr(other)}/({repr(self)})",
+                         h=lambda u: 1 / u, dh=lambda u: 1 / u ** 2)
         if np.all(a >= 0) or np.all(b <= 0):
             out = MonotonicTransformedDistribution(self, **funcs, increasing=False)
         else:
@@ -1860,9 +1858,11 @@ def __rtruediv__(self, other):
         return out * other

     def __rpow__(self, other):
-        funcs = dict(g=lambda u: other**u, g_name=f'{other}**',
-                     h=lambda u: np.log(u) / np.log(other),
-                     dh=lambda u: 1 / np.abs(u * np.log(other)))
+        with np.printoptions(threshold=10):
+            funcs = dict(g=lambda u: other**u,
+                         h=lambda u: np.log(u) / np.log(other),
+                         dh=lambda u: 1 / np.abs(u * np.log(other)),
+                         repr_pattern=f"{repr(other)}**({repr(self)})")

         if not np.isscalar(other) or other <= 0 or other == 1:
             message = ("Raising an argument to the power of a random variable is only "
@@ -3846,9 +3846,7 @@ def _process_parameters(self, **params):
         return self._dist._process_parameters(**params)

     def __repr__(self):
-        s = super().__repr__()
-        return s.replace("Distribution",
-                         self._dist.__class__.__name__)
+        raise NotImplementedError()


 class TruncatedDistribution(TransformedDistribution):
@@ -3926,6 +3924,11 @@ def _iccdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
         p_adjusted = cFb + p*np.exp(logmass)
         return self._dist._iccdf_dispatch(p_adjusted, *args, **params)

+    def __repr__(self):
+        with np.printoptions(threshold=10):
+            return (f"truncate({repr(self._dist)}, "
+                    f"lb={repr(self.lb)}, ub={repr(self.ub)})")
+

 def truncate(X, lb=-np.inf, ub=np.inf):
     """Truncate the support of a random variable.
@@ -4026,6 +4029,18 @@ def _support(self, loc, scale, sign, **params):
         a, b = self._itransform(a, loc, scale), self._itransform(b, loc, scale)
         return np.where(sign, a, b)[()], np.where(sign, b, a)[()]

+    def __repr__(self):
+        with np.printoptions(threshold=10):
+            result = f"{repr(self.scale)}*{repr(self._dist)}"
+            if not self.loc.ndim and self.loc < 0:
+                result += f" - {repr(-self.loc)}"
+            elif (np.any(self.loc != 0)
+                  or not np.can_cast(self.loc.dtype, self.scale.dtype)):
+                # We don't want to hide a zero array loc if it can cause
+                # a type promotion.
+                result += f" + {repr(self.loc)}"
+            return result
+
     # Here, we override all the `_dispatch` methods rather than the public
     # methods or _function methods. Why not the public methods?
     # If we were to override the public methods, then other
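The branches above control how the shift is printed: a negative scalar `loc` is rendered as a subtraction, a nonzero (or dtype-changing) `loc` as an addition, and an inert zero is dropped. The expected strings in the updated test below give the flavour (NumPy 1.x scalar formatting shown; 2.x wraps scalars in `np.float64(...)`):

    # First expected string copied from test_ContinuousDistribution__repr__ below;
    # the second is an inference from the `loc < 0` branch above, not from the tests.
    from scipy.stats._new_distributions import Uniform  # private import, as in the tests

    X = Uniform(a=0, b=1)
    repr(X*3 + 2)   # "3.0*Uniform(a=0.0, b=1.0) + 2.0" on NumPy 1.x
    repr(X*3 - 2)   # presumably "3.0*Uniform(a=0.0, b=1.0) - 2.0"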
@@ -4298,6 +4313,11 @@ def _iccdf_formula(self, p, r, n, **kwargs):
         p_ = special.betainccinv(r, n-r+1, p)
         return self._dist._icdf_dispatch(p_, **kwargs)

+    def __repr__(self):
+        with np.printoptions(threshold=10):
+            return (f"order_statistic({repr(self._dist)}, r={repr(self.r)}, "
+                    f"n={repr(self.n)})")
+

 def order_statistic(X, /, *, r, n):
     r"""Probability distribution of an order statistic
@@ -4678,6 +4698,17 @@ def sample(self, shape=(), *, rng=None, method=None):
         x = np.reshape(rng.permuted(np.concatenate(x)), shape)
         return x[()]

+    def __repr__(self):
+        result = "Mixture(\n"
+        result += " [\n"
+        with np.printoptions(threshold=10):
+            for component in self.components:
+                result += f" {repr(component)},\n"
+        result += " ],\n"
+        result += f" weights={repr(self.weights)},\n"
+        result += ")"
+        return result
+

 class MonotonicTransformedDistribution(TransformedDistribution):
     r"""Distribution underlying a strictly monotonic function of a random variable
@@ -4701,14 +4732,18 @@ class MonotonicTransformedDistribution(TransformedDistribution):
     increasing : bool, optional
         Whether the function is strictly increasing (True, default)
        or strictly decreasing (False).
-    g_name : str, optional
-        The name of the mathematical function represented by `g`,
-        used in `__repr__` and `__str__`. The default is ``g.__name__``.
+    repr_pattern : str, optional
+        A string pattern for determining the __repr__. The __repr__
+        for X will be substituted into the position where `***` appears.
+        For example:
+        ``"exp(***)"`` for the repr of an exponentially transformed
+        distribution
+        The default is ``f"{g.__name__}(***)"``.

     """

     def __init__(self, X, /, *args, g, h, dh, logdh=None,
-                 increasing=True, g_name=None, **kwargs):
+                 increasing=True, repr_pattern=None, **kwargs):
         super().__init__(X, *args, **kwargs)
         self._g = g
         self._h = h
@@ -4734,13 +4769,11 @@ def __init__(self, X, /, *args, g, h, dh, logdh=None,
             self._ilogxdf = self._dist._ilogccdf_dispatch
             self._ilogcxdf = self._dist._ilogcdf_dispatch
         self._increasing = increasing
-        self._g_name = g.__name__ if g_name is None else g_name
+        self._repr_pattern = repr_pattern or f"{g.__name__}(***)"

     def __repr__(self):
-        return f"{self._g_name}({repr(self._dist)})"
-
-    def __str__(self):
-        return f"{self._g_name}({str(self._dist)})"
+        with np.printoptions(threshold=10):
+            return self._repr_pattern.replace("***", repr(self._dist))

     def _overrides(self, method_name):
         # Do not use the generic overrides of TransformedDistribution
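The `***` placeholder documented above is plain string substitution performed at repr time; a standalone miniature of what the stored pattern does (hypothetical values, not scipy internals):

    repr_pattern = "exp(***)"                   # the default would be f"{g.__name__}(***)"
    inner = "Uniform(a=1.0, b=2.0)"             # stands in for repr(self._dist)
    print(repr_pattern.replace("***", inner))   # exp(Uniform(a=1.0, b=2.0))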
@@ -4892,6 +4925,10 @@ def _sample_dispatch(self, sample_shape, full_shape, *,
             sample_shape, full_shape, method=method, rng=rng, **params)
         return np.abs(rvs)

+    def __repr__(self):
+        with np.printoptions(threshold=10):
+            return f"abs({repr(self._dist)})"
+

 def abs(X, /):
     r"""Absolute value of a random variable

scipy/stats/tests/test_continuous.py

Lines changed: 114 additions & 11 deletions
@@ -1114,11 +1114,17 @@ def test_repr_str_docs(self):
         assert hasattr(stats, dist)

         dist = stats.make_distribution(stats.gamma)
-        assert str(dist(a=2)) == "Gamma(a=2.0)"
+        if np.__version__ < "2":
+            assert str(dist(a=2)) == "Gamma(a=2.0)"
+        else:
+            assert str(dist(a=2)) == "Gamma(a=np.float64(2.0))"
         assert 'Gamma' in dist.__doc__

         dist = stats.make_distribution(stats.halfgennorm)
-        assert str(dist(beta=2)) == "HalfGeneralizedNormal(beta=2.0)"
+        if np.__version__ < "2":
+            str(dist(beta=2)) == "HalfGeneralizedNormal(beta=2.0)"
+        else:
+            assert str(dist(beta=2)) == "HalfGeneralizedNormal(beta=np.float64(2.0))"
         assert 'HalfGeneralizedNormal' in dist.__doc__


@@ -1381,10 +1387,18 @@ def test_log(self):
     def test_monotonic_transforms(self):
         # Some tests of monotonic transforms that are better to be grouped or
         # don't fit well above
+
         X = Uniform(a=1, b=2)
-        assert repr(stats.log(X)) == str(stats.log(X)) == "log(Uniform(a=1.0, b=2.0))"
-        assert repr(1 / X) == str(1 / X) == "inv(Uniform(a=1.0, b=2.0))"
-        assert repr(stats.exp(X)) == str(stats.exp(X)) == "exp(Uniform(a=1.0, b=2.0))"
+        X_repr = (
+            "Uniform(a=1.0, b=2.0)" if np.__version__ < "2"
+            else "Uniform(a=np.float64(1.0), b=np.float64(2.0))"
+        )
+
+        assert repr(stats.log(X)) == str(stats.log(X)) == (
+            f"log({X_repr})"
+        )
+        assert repr(1 / X) == str(1 / X) == f"1/({X_repr})"
+        assert repr(stats.exp(X)) == str(stats.exp(X)) == f"exp({X_repr})"

         X = Uniform(a=-1, b=2)
         message = "Division by a random variable is only implemented when the..."
@@ -1634,17 +1648,106 @@ def test_generate_domain_support(self):
         msg = _generate_domain_support(_LogUniform)
         assert "accepts two parameterizations" in msg

-    def test_ContinuousDistribution__str__(self):
+    def test_ContinuousDistribution__repr__(self):
         X = Uniform(a=0, b=1)
-        assert str(X) == "Uniform(a=0.0, b=1.0)"
-
-        assert str(X*3 + 2) == "ShiftedScaledUniform(a=0.0, b=1.0, loc=2.0, scale=3.0)"
+        if np.__version__ < "2":
+            assert repr(X) == "Uniform(a=0.0, b=1.0)"
+        else:
+            assert repr(X) == "Uniform(a=np.float64(0.0), b=np.float64(1.0))"
+        if np.__version__ < "2":
+            assert repr(X*3 + 2) == "3.0*Uniform(a=0.0, b=1.0) + 2.0"
+        else:
+            assert repr(X*3 + 2) == (
+                "np.float64(3.0)*Uniform(a=np.float64(0.0), b=np.float64(1.0))"
+                " + np.float64(2.0)"
+            )

         X = Uniform(a=np.zeros(4), b=1)
-        assert str(X) == "Uniform(a, b, shape=(4,))"
+        assert repr(X) == "Uniform(a=array([0., 0., 0., 0.]), b=1)"

         X = Uniform(a=np.zeros(4, dtype=np.float32), b=np.ones(4, dtype=np.float32))
-        assert str(X) == "Uniform(a, b, shape=(4,), dtype=float32)"
+        assert repr(X) == (
+            "Uniform(a=array([0., 0., 0., 0.], dtype=float32),"
+            " b=array([1., 1., 1., 1.], dtype=float32))"
+        )
+
+
+class TestReprs:
+    U = Uniform(a=0, b=1)
+    V = Uniform(a=np.float32(0.0), b=np.float32(1.0))
+    X = Normal(mu=-1, sigma=1)
+    Y = Normal(mu=1, sigma=1)
+    Z = Normal(mu=np.zeros(1000), sigma=1)
+
+    @pytest.mark.parametrize(
+        "dist",
+        [
+            U,
+            U - np.array([1.0, 2.0]),
+            pytest.param(
+                V,
+                marks=pytest.mark.skipif(
+                    np.__version__ < "2",
+                    reason="numpy 1.x didn't have dtype in repr",
+                )
+            ),
+            pytest.param(
+                np.ones(2, dtype=np.float32)*V + np.zeros(2, dtype=np.float64),
+                marks=pytest.mark.skipif(
+                    np.__version__ < "2",
+                    reason="numpy 1.x didn't have dtype in repr",
+                )
+            ),
+            3*U + 2,
+            U**4,
+            (3*U + 2)**4,
+            (3*U + 2)**3,
+            2**U,
+            2**(3*U + 1),
+            1 / (1 + U),
+            stats.order_statistic(U, r=3, n=5),
+            stats.truncate(U, 0.2, 0.8),
+            stats.Mixture([X, Y], weights=[0.3, 0.7]),
+            abs(U),
+            stats.exp(U),
+            stats.log(1 + U),
+            np.array([1.0, 2.0])*U + np.array([2.0, 3.0]),
+        ]
+    )
+    def test_executable(self, dist):
+        # Test that reprs actually evaluate to proper distribution
+        # provided relevant imports are made.
+        from numpy import array  # noqa: F401
+        from numpy import float32  # noqa: F401
+        from scipy.stats import abs, exp, log, order_statistic, truncate  # noqa: F401
+        from scipy.stats import Mixture, Normal  # noqa: F401
+        from scipy.stats._new_distributions import Uniform  # noqa: F401
+        new_dist = eval(repr(dist))
+        # A basic check that the distributions are the same
+        sample1 = dist.sample(shape=10, rng=1234)
+        sample2 = new_dist.sample(shape=10, rng=1234)
+        assert_equal(sample1, sample2)
+        assert sample1.dtype is sample2.dtype
+
+    @pytest.mark.parametrize(
+        "dist",
+        [
+            Z,
+            np.full(1000, 2.0) * X + 1.0,
+            2.0 * X + np.full(1000, 1.0),
+            np.full(1000, 2.0) * X + 1.0,
+            stats.truncate(Z, -1, 1),
+            stats.truncate(Z, -np.ones(1000), np.ones(1000)),
+            stats.order_statistic(X, r=np.arange(1, 1000), n=1000),
+            Z**2,
+            1.0 / (1 + stats.exp(Z)),
+            2**Z,
+        ]
+    )
+    def test_not_too_long(self, dist):
+        # Tests that array summarization is working to ensure reprs aren't too long.
+        # None of the reprs above will be executable.
+        assert len(repr(dist)) < 250


 class MixedDist(ContinuousDistribution):
