27
27
def _isnull (x ):
28
28
return type (x ) is object or x is None
29
29
30
- __all__ = ['ContinuousDistribution' ]
30
+ __all__ = ['make_distribution' , 'Mixture' , 'order_statistic' ,
31
+ 'truncate' , 'abs' , 'exp' , 'log' ]
31
32
32
33
# Could add other policies for broadcasting and edge/out-of-bounds case handling
33
34
# For instance, when edge case handling is known not to be needed, it's much
@@ -1482,6 +1483,7 @@ class ContinuousDistribution(_ProbabilityDistribution):
1482
1483
text.
1483
1484
1484
1485
"""
1486
+ __array_priority__ = 1
1485
1487
_parameterizations = [] # type: ignore[var-annotated]
1486
1488
1487
1489
### Initialization
@@ -1501,7 +1503,8 @@ def __init__(self, *, tol=_null, validation_policy=None, cache_policy=None,
1501
1503
# IDEs can suggest parameter names. If there are multiple parameterizations,
1502
1504
# we'll need the default values of parameters to be None; this will
1503
1505
# filter out the parameters that were not actually specified by the user.
1504
- parameters = {key : val for key , val in parameters .items () if val is not None }
1506
+ parameters = {key : val for key , val in
1507
+ sorted (parameters .items ()) if val is not None }
1505
1508
self ._update_parameters (** parameters )
1506
1509
1507
1510
def _update_parameters (self , * , validation_policy = None , ** params ):
@@ -1701,9 +1704,7 @@ def _process_parameters(self, **params):
1701
1704
1702
1705
def _get_parameter_str (self , parameters ):
1703
1706
# Get a string representation of the parameters like "{a, b, c}".
1704
- parameter_names_list = list (parameters .keys ())
1705
- parameter_names_list .sort ()
1706
- return f"{{{ ', ' .join (parameter_names_list )} }}"
1707
+ return f"{{{ ', ' .join (parameters .keys ())} }}"
1707
1708
1708
1709
def _copy_parameterization (self ):
1709
1710
self ._parameterizations = self ._parameterizations .copy ()
@@ -1786,25 +1787,17 @@ def __repr__(self):
1786
1787
r""" Returns a string representation of the distribution.
1787
1788
1788
1789
Includes the name of the distribution family, the names of the
1789
- parameters, and the broadcasted shape and result dtype of the
1790
- parameters.
1790
+ parameters and the `repr` of each of their values.
1791
+
1791
1792
1792
1793
"""
1793
1794
class_name = self .__class__ .__name__
1794
1795
parameters = list (self ._original_parameters .items ())
1795
1796
info = []
1796
- if parameters :
1797
- parameters .sort ()
1798
- if self ._size <= 3 :
1799
- str_parameters = [f"{ symbol } ={ value } " for symbol , value in parameters ]
1800
- str_parameters = f"{ ', ' .join (str_parameters )} "
1801
- else :
1802
- str_parameters = f"{ ', ' .join ([symbol for symbol , _ in parameters ])} "
1803
- info .append (str_parameters )
1804
- if self ._shape :
1805
- info .append (f"shape={ self ._shape } " )
1806
- if self ._dtype != np .float64 :
1807
- info .append (f"dtype={ self ._dtype } " )
1797
+ with np .printoptions (threshold = 10 ):
1798
+ str_parameters = [f"{ symbol } ={ repr (value )} " for symbol , value in parameters ]
1799
+ str_parameters = f"{ ', ' .join (str_parameters )} "
1800
+ info .append (str_parameters )
1808
1801
return f"{ class_name } ({ ', ' .join (info )} )"
1809
1802
1810
1803
def __add__ (self , loc ):
@@ -1825,10 +1818,13 @@ def __pow__(self, other):
1825
1818
"implemented when the argument is a positive integer." )
1826
1819
raise NotImplementedError (message )
1827
1820
1828
- X = abs (self ) if (other % 2 == 0 ) else self
1821
+ # Fill in repr_pattern with the repr of self before taking abs.
1822
+ # Avoids having unnecessary abs in the repr.
1823
+ with np .printoptions (threshold = 10 ):
1824
+ repr_pattern = f"({ repr (self )} )**{ repr (other )} "
1825
+ X = abs (self ) if other % 2 == 0 else self
1829
1826
1830
- # This notation for g_name is nonstandard
1831
- funcs = dict (g = lambda u : u ** other , g_name = f'pow_{ other } ' ,
1827
+ funcs = dict (g = lambda u : u ** other , repr_pattern = repr_pattern ,
1832
1828
h = lambda u : np .sign (u ) * np .abs (u )** (1 / other ),
1833
1829
dh = lambda u : 1 / other * np .abs (u )** (1 / other - 1 ))
1834
1830
@@ -1846,8 +1842,10 @@ def __rmul__(self, other):
1846
1842
1847
1843
def __rtruediv__ (self , other ):
1848
1844
a , b = self .support ()
1849
- funcs = dict (g = lambda u : 1 / u , g_name = 'inv' ,
1850
- h = lambda u : 1 / u , dh = lambda u : 1 / u ** 2 )
1845
+ with np .printoptions (threshold = 10 ):
1846
+ funcs = dict (g = lambda u : 1 / u ,
1847
+ repr_pattern = f"{ repr (other )} /({ repr (self )} )" ,
1848
+ h = lambda u : 1 / u , dh = lambda u : 1 / u ** 2 )
1851
1849
if np .all (a >= 0 ) or np .all (b <= 0 ):
1852
1850
out = MonotonicTransformedDistribution (self , ** funcs , increasing = False )
1853
1851
else :
@@ -1860,9 +1858,11 @@ def __rtruediv__(self, other):
1860
1858
return out * other
1861
1859
1862
1860
def __rpow__ (self , other ):
1863
- funcs = dict (g = lambda u : other ** u , g_name = f'{ other } **' ,
1864
- h = lambda u : np .log (u ) / np .log (other ),
1865
- dh = lambda u : 1 / np .abs (u * np .log (other )))
1861
+ with np .printoptions (threshold = 10 ):
1862
+ funcs = dict (g = lambda u : other ** u ,
1863
+ h = lambda u : np .log (u ) / np .log (other ),
1864
+ dh = lambda u : 1 / np .abs (u * np .log (other )),
1865
+ repr_pattern = f"{ repr (other )} **({ repr (self )} )" )
1866
1866
1867
1867
if not np .isscalar (other ) or other <= 0 or other == 1 :
1868
1868
message = ("Raising an argument to the power of a random variable is only "
@@ -3846,9 +3846,7 @@ def _process_parameters(self, **params):
3846
3846
return self ._dist ._process_parameters (** params )
3847
3847
3848
3848
def __repr__ (self ):
3849
- s = super ().__repr__ ()
3850
- return s .replace ("Distribution" ,
3851
- self ._dist .__class__ .__name__ )
3849
+ raise NotImplementedError ()
3852
3850
3853
3851
3854
3852
class TruncatedDistribution (TransformedDistribution ):
@@ -3926,6 +3924,11 @@ def _iccdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
3926
3924
p_adjusted = cFb + p * np .exp (logmass )
3927
3925
return self ._dist ._iccdf_dispatch (p_adjusted , * args , ** params )
3928
3926
3927
+ def __repr__ (self ):
3928
+ with np .printoptions (threshold = 10 ):
3929
+ return (f"truncate({ repr (self ._dist )} , "
3930
+ f"lb={ repr (self .lb )} , ub={ repr (self .ub )} )" )
3931
+
3929
3932
3930
3933
def truncate (X , lb = - np .inf , ub = np .inf ):
3931
3934
"""Truncate the support of a random variable.
@@ -4026,6 +4029,18 @@ def _support(self, loc, scale, sign, **params):
4026
4029
a , b = self ._itransform (a , loc , scale ), self ._itransform (b , loc , scale )
4027
4030
return np .where (sign , a , b )[()], np .where (sign , b , a )[()]
4028
4031
4032
+ def __repr__ (self ):
4033
+ with np .printoptions (threshold = 10 ):
4034
+ result = f"{ repr (self .scale )} *{ repr (self ._dist )} "
4035
+ if not self .loc .ndim and self .loc < 0 :
4036
+ result += f" - { repr (- self .loc )} "
4037
+ elif (np .any (self .loc != 0 )
4038
+ or not np .can_cast (self .loc .dtype , self .scale .dtype )):
4039
+ # We don't want to hide a zero array loc if it can cause
4040
+ # a type promotion.
4041
+ result += f" + { repr (self .loc )} "
4042
+ return result
4043
+
4029
4044
# Here, we override all the `_dispatch` methods rather than the public
4030
4045
# methods or _function methods. Why not the public methods?
4031
4046
# If we were to override the public methods, then other
@@ -4298,6 +4313,11 @@ def _iccdf_formula(self, p, r, n, **kwargs):
4298
4313
p_ = special .betainccinv (r , n - r + 1 , p )
4299
4314
return self ._dist ._icdf_dispatch (p_ , ** kwargs )
4300
4315
4316
+ def __repr__ (self ):
4317
+ with np .printoptions (threshold = 10 ):
4318
+ return (f"order_statistic({ repr (self ._dist )} , r={ repr (self .r )} , "
4319
+ f"n={ repr (self .n )} )" )
4320
+
4301
4321
4302
4322
def order_statistic (X , / , * , r , n ):
4303
4323
r"""Probability distribution of an order statistic
@@ -4678,6 +4698,17 @@ def sample(self, shape=(), *, rng=None, method=None):
4678
4698
x = np .reshape (rng .permuted (np .concatenate (x )), shape )
4679
4699
return x [()]
4680
4700
4701
+ def __repr__ (self ):
4702
+ result = "Mixture(\n "
4703
+ result += " [\n "
4704
+ with np .printoptions (threshold = 10 ):
4705
+ for component in self .components :
4706
+ result += f" { repr (component )} ,\n "
4707
+ result += " ],\n "
4708
+ result += f" weights={ repr (self .weights )} ,\n "
4709
+ result += ")"
4710
+ return result
4711
+
4681
4712
4682
4713
class MonotonicTransformedDistribution (TransformedDistribution ):
4683
4714
r"""Distribution underlying a strictly monotonic function of a random variable
@@ -4701,14 +4732,18 @@ class MonotonicTransformedDistribution(TransformedDistribution):
4701
4732
increasing : bool, optional
4702
4733
Whether the function is strictly increasing (True, default)
4703
4734
or strictly decreasing (False).
4704
- g_name : str, optional
4705
- The name of the mathematical function represented by `g`,
4706
- used in `__repr__` and `__str__`. The default is ``g.__name__``.
4735
+ repr_pattern : str, optional
4736
+ A string pattern for determining the __repr__. The __repr__
4737
+ for X will be substituted into the position where `***` appears.
4738
+ For example:
4739
+ ``"exp(***)"`` for the repr of an exponentially transformed
4740
+ distribution
4741
+ The default is ``f"{g.__name__}(***)"``.
4707
4742
4708
4743
"""
4709
4744
4710
4745
def __init__ (self , X , / , * args , g , h , dh , logdh = None ,
4711
- increasing = True , g_name = None , ** kwargs ):
4746
+ increasing = True , repr_pattern = None , ** kwargs ):
4712
4747
super ().__init__ (X , * args , ** kwargs )
4713
4748
self ._g = g
4714
4749
self ._h = h
@@ -4734,13 +4769,11 @@ def __init__(self, X, /, *args, g, h, dh, logdh=None,
4734
4769
self ._ilogxdf = self ._dist ._ilogccdf_dispatch
4735
4770
self ._ilogcxdf = self ._dist ._ilogcdf_dispatch
4736
4771
self ._increasing = increasing
4737
- self ._g_name = g .__name__ if g_name is None else g_name
4772
+ self ._repr_pattern = repr_pattern or f" { g .__name__ } (***)"
4738
4773
4739
4774
def __repr__ (self ):
4740
- return f"{ self ._g_name } ({ repr (self ._dist )} )"
4741
-
4742
- def __str__ (self ):
4743
- return f"{ self ._g_name } ({ str (self ._dist )} )"
4775
+ with np .printoptions (threshold = 10 ):
4776
+ return self ._repr_pattern .replace ("***" , repr (self ._dist ))
4744
4777
4745
4778
def _overrides (self , method_name ):
4746
4779
# Do not use the generic overrides of TransformedDistribution
@@ -4892,6 +4925,10 @@ def _sample_dispatch(self, sample_shape, full_shape, *,
4892
4925
sample_shape , full_shape , method = method , rng = rng , ** params )
4893
4926
return np .abs (rvs )
4894
4927
4928
+ def __repr__ (self ):
4929
+ with np .printoptions (threshold = 10 ):
4930
+ return f"abs({ repr (self ._dist )} )"
4931
+
4895
4932
4896
4933
def abs (X , / ):
4897
4934
r"""Absolute value of a random variable
0 commit comments