Skip to content

Commit 623fc91

Browse files
committed
Rename get_moment to moment
1 parent 9c2ba7a commit 623fc91

File tree

11 files changed

+140
-140
lines changed

11 files changed

+140
-140
lines changed

docs/source/contributing/developer_guide_implementing_distribution.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This guide provides an overview on how to implement a distribution for version 4
44
It is designed for developers who wish to add a new distribution to the library.
55
Users will not be aware of all this complexity and should instead make use of helper methods such as `~pymc.distributions.DensityDist`.
66

7-
PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
7+
PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `moment` methods as well as other initialization and validation helpers.
88
Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.
99

1010
Here is a summary check-list of the steps needed to implement a new distribution.
@@ -13,7 +13,7 @@ Each section will be expanded below:
1313
1. Creating a new `RandomVariable` `Op`
1414
1. Implementing the corresponding `Distribution` class
1515
1. Adding tests for the new `RandomVariable`
16-
1. Adding tests for `logp` / `logcdf` and `get_moment` methods
16+
1. Adding tests for `logp` / `logcdf` and `moment` methods
1717
1. Documenting the new `Distribution`.
1818

1919
This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
@@ -119,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
119119
PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
120120
In practice, they take care of:
121121

122-
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `get_moment`, `logp` and `logcdf` methods.
122+
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `moment`, `logp` and `logcdf` methods.
123123
1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
124124
1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
125125
1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
@@ -154,9 +154,9 @@ class Blah(PositiveContinuous):
154154
# the rv_op needs in order to be instantiated
155155
return super().dist([param1, param2], **kwargs)
156156

157-
# get_moment returns a symbolic expression for the stable moment from which to start sampling
157+
# moment returns a symbolic expression for the stable moment from which to start sampling
158158
# the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
159-
def get_moment(rv, size, param1, param2):
159+
def moment(rv, size, param1, param2):
160160
moment, _ = at.broadcast_arrays(param1, param2)
161161
if not rv_size_is_none(size):
162162
moment = at.full(size, moment)
@@ -198,25 +198,25 @@ Some notes:
198198
overriding `__new__`.
199199
1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
200200
1. The `logcdf` method is not a requirement, but it's a nice plus!
201-
1. Currently only one moment is supported in the `get_moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
202-
1. When creating the `get_moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
201+
1. Currently only one moment is supported in the `moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
202+
1. When creating the `moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
203203

204204
For a quick check that things are working you can try the following:
205205

206206
```python
207207

208208
import pymc as pm
209-
from pymc.distributions.distribution import get_moment
209+
from pymc.distributions.distribution import moment
210210

211211
# pm.blah = pm.Normal in this example
212-
blah = pm.blah.dist(mu = 0, sigma = 1)
212+
blah = pm.blah.dist(mu=0, sigma=1)
213213

214214
# Test that the returned blah_op is still working fine
215215
blah.eval()
216216
# array(-1.01397228)
217217

218-
# Test the get_moment method
219-
get_moment(blah).eval()
218+
# Test the moment method
219+
moment(blah).eval()
220220
# array(0.)
221221

222222
# Test the logp method
@@ -367,9 +367,9 @@ def test_blah_logcdf(self):
367367

368368
```
369369

370-
## 5. Adding tests for the `get_moment` method
370+
## 5. Adding tests for the `moment` method
371371

372-
Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
372+
Tests for the `moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
373373
which checks if:
374374
1. Moments return the `expected` values
375375
1. Moments have the expected size and shape

pymc/distributions/censored.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from aesara.tensor import TensorVariable
1919
from aesara.tensor.random.op import RandomVariable
2020

21-
from pymc.distributions.distribution import SymbolicDistribution, _get_moment
21+
from pymc.distributions.distribution import SymbolicDistribution, _moment
2222
from pymc.util import check_dist_not_registered
2323

2424

@@ -124,8 +124,8 @@ def graph_rvs(cls, rv):
124124
return (rv.tag.dist,)
125125

126126

127-
@_get_moment.register(Clip)
128-
def get_moment_censored(op, rv, dist, lower, upper):
127+
@_moment.register(Clip)
128+
def moment_censored(op, rv, dist, lower, upper):
129129
moment = at.switch(
130130
at.eq(lower, -np.inf),
131131
at.switch(

pymc/distributions/continuous.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def dist(cls, lower=0, upper=1, **kwargs):
305305
upper = at.as_tensor_variable(floatX(upper))
306306
return super().dist([lower, upper], **kwargs)
307307

308-
def get_moment(rv, size, lower, upper):
308+
def moment(rv, size, lower, upper):
309309
lower, upper = at.broadcast_arrays(lower, upper)
310310
moment = (lower + upper) / 2
311311
if not rv_size_is_none(size):
@@ -370,7 +370,7 @@ def dist(cls, *, size=None, **kwargs):
370370
res = super().dist([], size=size, **kwargs)
371371
return res
372372

373-
def get_moment(rv, size):
373+
def moment(rv, size):
374374
return at.zeros(size)
375375

376376
def logp(value):
@@ -438,7 +438,7 @@ def dist(cls, *, size=None, **kwargs):
438438
res = super().dist([], size=size, **kwargs)
439439
return res
440440

441-
def get_moment(rv, size):
441+
def moment(rv, size):
442442
return at.ones(size)
443443

444444
def logp(value):
@@ -556,7 +556,7 @@ def dist(cls, mu=0, sigma=None, tau=None, no_assert=False, **kwargs):
556556

557557
return super().dist([mu, sigma], **kwargs)
558558

559-
def get_moment(rv, size, mu, sigma):
559+
def moment(rv, size, mu, sigma):
560560
mu, _ = at.broadcast_arrays(mu, sigma)
561561
if not rv_size_is_none(size):
562562
mu = at.full(size, mu)
@@ -716,7 +716,7 @@ def dist(
716716
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)
717717
return super().dist([mu, sigma, lower, upper], **kwargs)
718718

719-
def get_moment(rv, size, mu, sigma, lower, upper):
719+
def moment(rv, size, mu, sigma, lower, upper):
720720
mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper)
721721
moment = at.switch(
722722
at.eq(lower, -np.inf),
@@ -865,7 +865,7 @@ def dist(cls, sigma=None, tau=None, *args, **kwargs):
865865

866866
return super().dist([0.0, sigma], **kwargs)
867867

868-
def get_moment(rv, size, loc, sigma):
868+
def moment(rv, size, loc, sigma):
869869
moment = loc + sigma
870870
if not rv_size_is_none(size):
871871
moment = at.full(size, moment)
@@ -1017,7 +1017,7 @@ def dist(
10171017

10181018
return super().dist([mu, lam, alpha], **kwargs)
10191019

1020-
def get_moment(rv, size, mu, lam, alpha):
1020+
def moment(rv, size, mu, lam, alpha):
10211021
mu, _, _ = at.broadcast_arrays(mu, lam, alpha)
10221022
if not rv_size_is_none(size):
10231023
mu = at.full(size, mu)
@@ -1225,7 +1225,7 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, *args, **kwargs):
12251225

12261226
return super().dist([alpha, beta], **kwargs)
12271227

1228-
def get_moment(rv, size, alpha, beta):
1228+
def moment(rv, size, alpha, beta):
12291229
mean = alpha / (alpha + beta)
12301230
if not rv_size_is_none(size):
12311231
mean = at.full(size, mean)
@@ -1356,7 +1356,7 @@ def dist(cls, a, b, *args, **kwargs):
13561356

13571357
return super().dist([a, b], *args, **kwargs)
13581358

1359-
def get_moment(rv, size, a, b):
1359+
def moment(rv, size, a, b):
13601360
mean = at.exp(at.log(b) + at.gammaln(1 + 1 / a) + at.gammaln(b) - at.gammaln(1 + 1 / a + b))
13611361
if not rv_size_is_none(size):
13621362
mean = at.full(size, mean)
@@ -1476,7 +1476,7 @@ def dist(cls, lam, *args, **kwargs):
14761476
# Aesara exponential op is parametrized in terms of mu (1/lam)
14771477
return super().dist([at.inv(lam)], **kwargs)
14781478

1479-
def get_moment(rv, size, mu):
1479+
def moment(rv, size, mu):
14801480
if not rv_size_is_none(size):
14811481
mu = at.full(size, mu)
14821482
return mu
@@ -1560,7 +1560,7 @@ def dist(cls, mu, b, *args, **kwargs):
15601560
assert_negative_support(b, "b", "Laplace")
15611561
return super().dist([mu, b], *args, **kwargs)
15621562

1563-
def get_moment(rv, size, mu, b):
1563+
def moment(rv, size, mu, b):
15641564
mu, _ = at.broadcast_arrays(mu, b)
15651565
if not rv_size_is_none(size):
15661566
mu = at.full(size, mu)
@@ -1671,7 +1671,7 @@ def dist(cls, b, kappa, mu=0, *args, **kwargs):
16711671

16721672
return super().dist([b, kappa, mu], *args, **kwargs)
16731673

1674-
def get_moment(rv, size, b, kappa, mu):
1674+
def moment(rv, size, b, kappa, mu):
16751675
mean = mu - (kappa - 1 / kappa) / b
16761676

16771677
if not rv_size_is_none(size):
@@ -1782,7 +1782,7 @@ def dist(cls, mu=0, sigma=None, tau=None, *args, **kwargs):
17821782

17831783
return super().dist([mu, sigma], *args, **kwargs)
17841784

1785-
def get_moment(rv, size, mu, sigma):
1785+
def moment(rv, size, mu, sigma):
17861786
mean = at.exp(mu + 0.5 * sigma**2)
17871787
if not rv_size_is_none(size):
17881788
mean = at.full(size, mean)
@@ -1907,7 +1907,7 @@ def dist(cls, nu, mu=0, lam=None, sigma=None, *args, **kwargs):
19071907

19081908
return super().dist([nu, mu, sigma], **kwargs)
19091909

1910-
def get_moment(rv, size, nu, mu, sigma):
1910+
def moment(rv, size, nu, mu, sigma):
19111911
mu, _, _ = at.broadcast_arrays(mu, nu, sigma)
19121912
if not rv_size_is_none(size):
19131913
mu = at.full(size, mu)
@@ -2025,7 +2025,7 @@ def dist(
20252025

20262026
return super().dist([alpha, m], **kwargs)
20272027

2028-
def get_moment(rv, size, alpha, m):
2028+
def moment(rv, size, alpha, m):
20292029
median = m * 2 ** (1 / alpha)
20302030
if not rv_size_is_none(size):
20312031
median = at.full(size, median)
@@ -2121,7 +2121,7 @@ def dist(cls, alpha, beta, *args, **kwargs):
21212121
assert_negative_support(beta, "beta", "Cauchy")
21222122
return super().dist([alpha, beta], **kwargs)
21232123

2124-
def get_moment(rv, size, alpha, beta):
2124+
def moment(rv, size, alpha, beta):
21252125
alpha, _ = at.broadcast_arrays(alpha, beta)
21262126
if not rv_size_is_none(size):
21272127
alpha = at.full(size, alpha)
@@ -2197,7 +2197,7 @@ def dist(cls, beta, *args, **kwargs):
21972197
assert_negative_support(beta, "beta", "HalfCauchy")
21982198
return super().dist([0.0, beta], **kwargs)
21992199

2200-
def get_moment(rv, size, loc, beta):
2200+
def moment(rv, size, loc, beta):
22012201
if not rv_size_is_none(size):
22022202
beta = at.full(size, beta)
22032203
return beta
@@ -2320,7 +2320,7 @@ def get_alpha_beta(cls, alpha=None, beta=None, mu=None, sigma=None):
23202320

23212321
return alpha, beta
23222322

2323-
def get_moment(rv, size, alpha, inv_beta):
2323+
def moment(rv, size, alpha, inv_beta):
23242324
# The Aesara `GammaRV` `Op` inverts the `beta` parameter itself
23252325
mean = alpha * inv_beta
23262326
if not rv_size_is_none(size):
@@ -2415,7 +2415,7 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, *args, **kwargs):
24152415

24162416
return super().dist([alpha, beta], **kwargs)
24172417

2418-
def get_moment(rv, size, alpha, beta):
2418+
def moment(rv, size, alpha, beta):
24192419
mean = beta / (alpha - 1.0)
24202420
mode = beta / (alpha + 1.0)
24212421
moment = at.switch(alpha > 1, mean, mode)
@@ -2525,7 +2525,7 @@ def dist(cls, nu, *args, **kwargs):
25252525
nu = at.as_tensor_variable(floatX(nu))
25262526
return super().dist([nu], *args, **kwargs)
25272527

2528-
def get_moment(rv, size, nu):
2528+
def moment(rv, size, nu):
25292529
moment = nu
25302530
if not rv_size_is_none(size):
25312531
moment = at.full(size, moment)
@@ -2620,7 +2620,7 @@ def dist(cls, alpha, beta, *args, **kwargs):
26202620

26212621
return super().dist([alpha, beta], *args, **kwargs)
26222622

2623-
def get_moment(rv, size, alpha, beta):
2623+
def moment(rv, size, alpha, beta):
26242624
mean = beta * at.gamma(1 + 1 / alpha)
26252625
if not rv_size_is_none(size):
26262626
mean = at.full(size, mean)
@@ -2739,7 +2739,7 @@ def dist(cls, nu=1, sigma=None, lam=None, *args, **kwargs):
27392739

27402740
return super().dist([nu, sigma], *args, **kwargs)
27412741

2742-
def get_moment(rv, size, nu, sigma):
2742+
def moment(rv, size, nu, sigma):
27432743
sigma, _ = at.broadcast_arrays(sigma, nu)
27442744
if not rv_size_is_none(size):
27452745
sigma = at.full(size, sigma)
@@ -2871,7 +2871,7 @@ def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs):
28712871

28722872
return super().dist([mu, sigma, nu], *args, **kwargs)
28732873

2874-
def get_moment(rv, size, mu, sigma, nu):
2874+
def moment(rv, size, mu, sigma, nu):
28752875
mu, nu, _ = at.broadcast_arrays(mu, nu, sigma)
28762876
moment = mu + nu
28772877
if not rv_size_is_none(size):
@@ -3005,7 +3005,7 @@ def dist(cls, mu=0.0, kappa=None, *args, **kwargs):
30053005
assert_negative_support(kappa, "kappa", "VonMises")
30063006
return super().dist([mu, kappa], *args, **kwargs)
30073007

3008-
def get_moment(rv, size, mu, kappa):
3008+
def moment(rv, size, mu, kappa):
30093009
mu, _ = at.broadcast_arrays(mu, kappa)
30103010
if not rv_size_is_none(size):
30113011
mu = at.full(size, mu)
@@ -3105,7 +3105,7 @@ def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, *args, **kwargs):
31053105

31063106
return super().dist([mu, sigma, alpha], *args, **kwargs)
31073107

3108-
def get_moment(rv, size, mu, sigma, alpha):
3108+
def moment(rv, size, mu, sigma, alpha):
31093109
mean = mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha**2) ** 0.5
31103110
if not rv_size_is_none(size):
31113111
mean = at.full(size, mean)
@@ -3202,7 +3202,7 @@ def dist(cls, lower=0, upper=1, c=0.5, *args, **kwargs):
32023202

32033203
return super().dist([lower, c, upper], *args, **kwargs)
32043204

3205-
def get_moment(rv, size, lower, c, upper):
3205+
def moment(rv, size, lower, c, upper):
32063206
mean = (lower + upper + c) / 3
32073207
if not rv_size_is_none(size):
32083208
mean = at.full(size, mean)
@@ -3309,7 +3309,7 @@ def dist(
33093309

33103310
return super().dist([mu, beta], **kwargs)
33113311

3312-
def get_moment(rv, size, mu, beta):
3312+
def moment(rv, size, mu, beta):
33133313
mean = mu + beta * np.euler_gamma
33143314
if not rv_size_is_none(size):
33153315
mean = at.full(size, mean)
@@ -3439,7 +3439,7 @@ def get_nu_b(cls, nu, b, sigma):
34393439
return nu, b, sigma
34403440
raise ValueError("Rice distribution must specify either nu" " or b.")
34413441

3442-
def get_moment(rv, size, nu, sigma):
3442+
def moment(rv, size, nu, sigma):
34433443
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)
34443444
mean = (
34453445
sigma
@@ -3538,7 +3538,7 @@ def dist(cls, mu=0.0, s=1.0, *args, **kwargs):
35383538
s = at.as_tensor_variable(floatX(s))
35393539
return super().dist([mu, s], *args, **kwargs)
35403540

3541-
def get_moment(rv, size, mu, s):
3541+
def moment(rv, size, mu, s):
35423542
mu, _ = at.broadcast_arrays(mu, s)
35433543
if not rv_size_is_none(size):
35443544
mu = at.full(size, mu)
@@ -3643,7 +3643,7 @@ def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
36433643

36443644
return super().dist([mu, sigma], **kwargs)
36453645

3646-
def get_moment(rv, size, mu, sigma):
3646+
def moment(rv, size, mu, sigma):
36473647
median, _ = at.broadcast_arrays(invlogit(mu), sigma)
36483648
if not rv_size_is_none(size):
36493649
median = at.full(size, median)
@@ -3793,7 +3793,7 @@ def dist(cls, x_points, pdf_points, *args, **kwargs):
37933793

37943794
return super().dist([x_points, pdf_points, cdf_points], **kwargs)
37953795

3796-
def get_moment(rv, size, x_points, pdf_points, cdf_points):
3796+
def moment(rv, size, x_points, pdf_points, cdf_points):
37973797
# cdf_points argument is unused
37983798
moment = at.sum(at.mul(x_points, pdf_points))
37993799

@@ -3905,7 +3905,7 @@ def dist(cls, mu=0, sigma=1.0, *args, **kwargs):
39053905

39063906
return super().dist([mu, sigma], *args, **kwargs)
39073907

3908-
def get_moment(rv, size, mu, sigma):
3908+
def moment(rv, size, mu, sigma):
39093909
mean = mu + sigma * (np.euler_gamma + at.log(2))
39103910

39113911
if not rv_size_is_none(size):
@@ -4119,7 +4119,7 @@ def dist(cls, h=1.0, z=0.0, **kwargs):
41194119

41204120
return super().dist([h, z], **kwargs)
41214121

4122-
def get_moment(rv, size, h, z):
4122+
def moment(rv, size, h, z):
41234123
mean = at.switch(at.eq(z, 0), h / 4, tanh(z / 2) * (h / (2 * z)))
41244124
if not rv_size_is_none(size):
41254125
mean = at.full(size, mean)

0 commit comments

Comments
 (0)