Skip to content

Commit f9b749b

Browse files
authored
Rename pm.Constant to pm.DiracDelta (#5903)
1 parent 66215ac commit f9b749b

File tree

8 files changed

+66
-38
lines changed

8 files changed

+66
-38
lines changed

docs/source/api/distributions/discrete.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Discrete
1212
DiscreteWeibull
1313
Poisson
1414
NegativeBinomial
15-
Constant
15+
DiracDelta
1616
ZeroInflatedPoisson
1717
ZeroInflatedBinomial
1818
ZeroInflatedNegativeBinomial

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
Binomial,
6464
Categorical,
6565
Constant,
66+
DiracDelta,
6667
DiscreteUniform,
6768
DiscreteWeibull,
6869
Geometric,
@@ -140,6 +141,7 @@
140141
"Bernoulli",
141142
"Poisson",
142143
"NegativeBinomial",
144+
"DiracDelta",
143145
"Constant",
144146
"ZeroInflatedPoisson",
145147
"ZeroInflatedNegativeBinomial",

pymc/distributions/discrete.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"DiscreteWeibull",
5757
"Poisson",
5858
"NegativeBinomial",
59+
"DiracDelta",
5960
"Constant",
6061
"ZeroInflatedPoisson",
6162
"ZeroInflatedBinomial",
@@ -1337,11 +1338,11 @@ def logp(value, p):
13371338
)
13381339

13391340

1340-
class ConstantRV(RandomVariable):
1341-
name = "constant"
1341+
class DiracDeltaRV(RandomVariable):
1342+
name = "diracdelta"
13421343
ndim_supp = 0
13431344
ndims_params = [0]
1344-
_print_name = ("Constant", "\\operatorname{Constant}")
1345+
_print_name = ("DiracDelta", "\\operatorname{DiracDelta}")
13451346

13461347
def make_node(self, rng, size, dtype, c):
13471348
c = at.as_tensor_variable(c)
@@ -1354,22 +1355,22 @@ def rng_fn(cls, rng, c, size=None):
13541355
return np.full(size, c)
13551356

13561357

1357-
constant = ConstantRV()
1358+
diracdelta = DiracDeltaRV()
13581359

13591360

1360-
class Constant(Discrete):
1361+
class DiracDelta(Discrete):
13611362
r"""
1362-
Constant log-likelihood.
1363+
DiracDelta log-likelihood.
13631364
13641365
Parameters
13651366
----------
13661367
c: float or int
1367-
Constant parameter. The dtype of `c` determines the dtype of the distribution.
1368-
This can affect which sampler is assigned to Constant variables, or variables
1369-
that use Constant, such as Mixtures.
1368+
Dirac Delta parameter. The dtype of `c` determines the dtype of the distribution.
1369+
This can affect which sampler is assigned to DiracDelta variables, or variables
1370+
that use DiracDelta, such as Mixtures.
13701371
"""
13711372

1372-
rv_op = constant
1373+
rv_op = diracdelta
13731374

13741375
@classmethod
13751376
def dist(cls, c, *args, **kwargs):
@@ -1385,7 +1386,7 @@ def moment(rv, size, c):
13851386

13861387
def logp(value, c):
13871388
r"""
1388-
Calculate log-probability of Constant distribution at specified value.
1389+
Calculate log-probability of DiracDelta distribution at specified value.
13891390
13901391
Parameters
13911392
----------
@@ -1411,6 +1412,23 @@ def logcdf(value, c):
14111412
)
14121413

14131414

1415+
class Constant:
1416+
def __new__(cls, *args, **kwargs):
1417+
warnings.warn(
1418+
"pm.Constant has been deprecated. Use pm.DiracDelta instead.",
1419+
FutureWarning,
1420+
)
1421+
return DiracDelta(*args, **kwargs)
1422+
1423+
@classmethod
1424+
def dist(cls, *args, **kwargs):
1425+
warnings.warn(
1426+
"pm.Constant has been deprecated. Use pm.DiracDelta instead.",
1427+
FutureWarning,
1428+
)
1429+
return DiracDelta.dist(*args, **kwargs)
1430+
1431+
14141432
def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
14151433
"""Helper function to create a zero-inflated mixture
14161434
@@ -1419,7 +1437,7 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
14191437
nonzero_p = at.as_tensor_variable(floatX(nonzero_p))
14201438
weights = at.stack([1 - nonzero_p, nonzero_p], axis=-1)
14211439
comp_dists = [
1422-
Constant.dist(0),
1440+
DiracDelta.dist(0),
14231441
nonzero_dist,
14241442
]
14251443
if name is not None:

pymc/tests/test_distributions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def polyagamma_cdf(*args, **kwargs):
7171
Cauchy,
7272
ChiSquared,
7373
Constant,
74+
DiracDelta,
7475
Dirichlet,
7576
DirichletMultinomial,
7677
DiscreteUniform,
@@ -1729,9 +1730,9 @@ def test_poisson(self):
17291730
{"mu": Rplus},
17301731
)
17311732

1732-
def test_constantdist(self):
1733-
check_logp(Constant, I, {"c": I}, lambda value, c: np.log(c == value))
1734-
check_logcdf(Constant, I, {"c": I}, lambda value, c: np.log(value >= c))
1733+
def test_diracdeltadist(self):
1734+
check_logp(DiracDelta, I, {"c": I}, lambda value, c: np.log(c == value))
1735+
check_logcdf(DiracDelta, I, {"c": I}, lambda value, c: np.log(value >= c))
17351736

17361737
def test_zeroinflatedpoisson(self):
17371738
def logp_fn(value, psi, mu):
@@ -3065,7 +3066,7 @@ def test_issue_4499(self):
30653066
assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)
30663067

30673068
with pm.Model(check_bounds=False) as m:
3068-
x = pm.Constant("x", 1, size=10)
3069+
x = pm.DiracDelta("x", 1, size=10)
30693070
assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)
30703071

30713072

@@ -3328,3 +3329,10 @@ def test_zero_inflated_dists_dtype_and_broadcast(dist, non_psi_args):
33283329
x = dist([0.5, 0.5, 0.5], *non_psi_args)
33293330
assert x.dtype in discrete_types
33303331
assert x.eval().shape == (3,)
3332+
3333+
3334+
def test_constantdist_deprecated():
3335+
with pytest.warns(FutureWarning, match="DiracDelta"):
3336+
with Model() as m:
3337+
x = Constant("x", c=1)
3338+
assert isinstance(x.owner.op, pm.distributions.discrete.DiracDeltaRV)

pymc/tests/test_distributions_moments.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
Categorical,
1919
Cauchy,
2020
ChiSquared,
21-
Constant,
2221
DensityDist,
22+
DiracDelta,
2323
Dirichlet,
2424
DirichletMultinomial,
2525
DiscreteUniform,
@@ -617,9 +617,9 @@ def test_negative_binomial_moment(n, p, size, expected):
617617
(np.arange(1, 6), None, np.arange(1, 6)),
618618
],
619619
)
620-
def test_constant_moment(c, size, expected):
620+
def test_diracdelta_moment(c, size, expected):
621621
with Model() as model:
622-
Constant("x", c=c, size=size)
622+
DiracDelta("x", c=c, size=size)
623623
assert_moment_is_expected(model, expected)
624624

625625

pymc/tests/test_distributions_random.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,17 +1501,17 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng):
15011501
]
15021502

15031503

1504-
class TestConstant(BaseTestDistributionRandom):
1505-
def constant_rng_fn(self, size, c):
1504+
class TestDiracDelta(BaseTestDistributionRandom):
1505+
def diracdelta_rng_fn(self, size, c):
15061506
if size is None:
15071507
return c
15081508
return np.full(size, c)
15091509

1510-
pymc_dist = pm.Constant
1510+
pymc_dist = pm.DiracDelta
15111511
pymc_dist_params = {"c": 3}
15121512
expected_rv_op_params = {"c": 3}
15131513
reference_dist_params = {"c": 3}
1514-
reference_dist = lambda self: self.constant_rng_fn
1514+
reference_dist = lambda self: self.diracdelta_rng_fn
15151515
checks_to_run = [
15161516
"check_pymc_params_match_rv_op",
15171517
"check_pymc_draws_match_reference",
@@ -1524,10 +1524,10 @@ def constant_rng_fn(self, size, c):
15241524
)
15251525
def test_dtype(self, floatX):
15261526
with aesara.config.change_flags(floatX=floatX):
1527-
assert pm.Constant.dist(2**4).dtype == "int8"
1528-
assert pm.Constant.dist(2**16).dtype == "int32"
1529-
assert pm.Constant.dist(2**32).dtype == "int64"
1530-
assert pm.Constant.dist(2.0).dtype == floatX
1527+
assert pm.DiracDelta.dist(2**4).dtype == "int8"
1528+
assert pm.DiracDelta.dist(2**16).dtype == "int32"
1529+
assert pm.DiracDelta.dist(2**32).dtype == "int64"
1530+
assert pm.DiracDelta.dist(2.0).dtype == floatX
15311531

15321532

15331533
class TestOrderedLogistic(BaseTestDistributionRandom):
@@ -1860,8 +1860,8 @@ def ref_rand(size, n, eta):
18601860

18611861
class TestLKJCholeskyCov(BaseTestDistributionRandom):
18621862
pymc_dist = _LKJCholeskyCov
1863-
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
1864-
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
1863+
pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
1864+
expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
18651865
size = None
18661866

18671867
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
@@ -1891,7 +1891,7 @@ def check_rv_size(self):
18911891
def check_draws_match_expected(self):
18921892
# TODO: Find better comparison:
18931893
rng = aesara.shared(self.get_random_state(reset=True))
1894-
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.Constant.dist([0.5, 2.0]), rng=rng)
1894+
x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.DiracDelta.dist([0.5, 2.0]), rng=rng)
18951895
assert np.all(np.abs(x.eval() - np.array([0.5, 0, 2.0])) < 0.01)
18961896

18971897

pymc/tests/test_distributions_timeseries.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from pymc.aesaraf import floatX
2424
from pymc.distributions.continuous import Flat, HalfNormal, Normal
25-
from pymc.distributions.discrete import Constant
25+
from pymc.distributions.discrete import DiracDelta
2626
from pymc.distributions.logprob import logp
2727
from pymc.distributions.multivariate import Dirichlet
2828
from pymc.distributions.timeseries import (
@@ -100,11 +100,11 @@ class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
100100
size = None
101101

102102
pymc_dist = pm.GaussianRandomWalk
103-
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.Constant.dist(0), "steps": 4}
103+
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init_dist": pm.DiracDelta.dist(0), "steps": 4}
104104
expected_rv_op_params = {
105105
"mu": 1.0,
106106
"sigma": 2,
107-
"init_dist": pm.Constant.dist(0),
107+
"init_dist": pm.DiracDelta.dist(0),
108108
"steps": 4,
109109
}
110110

@@ -455,7 +455,7 @@ def test_multivariate_init_dist(self):
455455
)
456456
def test_moment(self, size, expected):
457457
with Model() as model:
458-
init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]])
458+
init_dist = DiracDelta.dist([[1.0, 2.0], [3.0, 4.0]])
459459
AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size)
460460
assert_moment_is_expected(model, expected, check_finite_logp=False)
461461

pymc/tests/test_sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,7 @@ def test_issue_4490(self):
13141314
def test_aesara_function_kwargs(self):
13151315
sharedvar = aesara.shared(0)
13161316
with pm.Model() as m:
1317-
x = pm.Constant("x", 0)
1317+
x = pm.DiracDelta("x", 0)
13181318
y = pm.Deterministic("y", x + sharedvar)
13191319

13201320
prior = pm.sample_prior_predictive(
@@ -1361,7 +1361,7 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
13611361
def test_aesara_function_kwargs(self):
13621362
sharedvar = aesara.shared(0)
13631363
with pm.Model() as m:
1364-
x = pm.Constant("x", 0.0)
1364+
x = pm.DiracDelta("x", 0.0)
13651365
y = pm.Deterministic("y", x + sharedvar)
13661366

13671367
pp = pm.sample_posterior_predictive(
@@ -1434,7 +1434,7 @@ def test_draw_different_samples(self):
14341434

14351435
def test_draw_aesara_function_kwargs(self):
14361436
sharedvar = aesara.shared(0)
1437-
x = pm.Constant.dist(0.0)
1437+
x = pm.DiracDelta.dist(0.0)
14381438
y = x + sharedvar
14391439
draws = pm.draw(
14401440
y,

0 commit comments

Comments
 (0)