Skip to content

Commit edb832c

Browse files
committed
Allow Constant to be integer or float
1 parent 8cb71a5 commit edb832c

File tree

4 files changed

+44
-112
lines changed

4 files changed

+44
-112
lines changed

pymc/distributions/discrete.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pymc.distributions.logprob import logcdf, logp
4747
from pymc.distributions.shape_utils import rv_size_is_none
4848
from pymc.math import sigmoid
49+
from pymc.vartypes import continuous_types
4950

5051
__all__ = [
5152
"Binomial",
@@ -1315,9 +1316,12 @@ class ConstantRV(RandomVariable):
13151316
name = "constant"
13161317
ndim_supp = 0
13171318
ndims_params = [0]
1318-
dtype = "floatX" # Should be treated as a discrete variable!
13191319
_print_name = ("Constant", "\\operatorname{Constant}")
13201320

1321+
def make_node(self, rng, size, dtype, c):
1322+
c = at.as_tensor_variable(c)
1323+
return super().make_node(rng, size, c.dtype, c)
1324+
13211325
@classmethod
13221326
def rng_fn(cls, rng, c, size=None):
13231327
if size is None:
@@ -1334,15 +1338,19 @@ class Constant(Discrete):
13341338
13351339
Parameters
13361340
----------
1337-
value: float or int
1338-
Constant parameter.
1341+
c: float or int
1342+
Constant parameter. The dtype of `c` determines the dtype of the distribution.
1343+
This can affect which sampler is assigned to Constant variables, or variables
1344+
that use Constant, such as Mixtures.
13391345
"""
13401346

13411347
rv_op = constant
13421348

13431349
@classmethod
13441350
def dist(cls, c, *args, **kwargs):
1345-
c = at.as_tensor_variable(floatX(c))
1351+
c = at.as_tensor_variable(c)
1352+
if c.dtype in continuous_types:
1353+
c = floatX(c)
13461354
return super().dist([c], **kwargs)
13471355

13481356
def get_moment(rv, size, c):

pymc/tests/test_distributions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def polyagamma_cdf(*args, **kwargs):
130130
from pymc.math import kronecker
131131
from pymc.model import Deterministic, Model, Point, Potential
132132
from pymc.tests.helpers import select_by_precision
133-
from pymc.vartypes import continuous_types
133+
from pymc.vartypes import continuous_types, discrete_types
134134

135135

136136
def get_lkj_cases():
@@ -934,7 +934,9 @@ def check_selfconsistency_discrete_logcdf(
934934
Check that logcdf of discrete distributions matches sum of logps up to value
935935
"""
936936
# This test only works for scalar random variables
937-
assert distribution.rv_op.ndim_supp == 0
937+
rv_op = getattr(distribution, "rv_op", None)
938+
if rv_op:
939+
assert rv_op.ndim_supp == 0
938940

939941
domains = paramdomains.copy()
940942
domains["value"] = domain
@@ -3416,3 +3418,17 @@ def test_sd_dist_automatically_resized(self, sd_dist):
34163418
assert resized_sd_dist.eval().shape == (10, 3)
34173419
# LKJCov has support shape `(n * (n+1)) // 2`
34183420
assert x.eval().shape == (10, 6)
3421+
3422+
3423+
@pytest.mark.parametrize(
3424+
"dist, non_psi_args",
3425+
[
3426+
(pm.ZeroInflatedPoisson.dist, (2,)),
3427+
(pm.ZeroInflatedBinomial.dist, (2, 0.5)),
3428+
(pm.ZeroInflatedNegativeBinomial.dist, (2, 2)),
3429+
],
3430+
)
3431+
def test_zero_inflated_dists_dtype_and_broadcast(dist, non_psi_args):
3432+
x = dist([0.5, 0.5, 0.5], *non_psi_args)
3433+
assert x.dtype in discrete_types
3434+
assert x.eval().shape == (3,)

pymc/tests/test_distributions_moments.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def test_constant_moment(c, size, expected):
626626
@pytest.mark.parametrize(
627627
"psi, theta, size, expected",
628628
[
629-
(0.9, 3.0, None, 2),
629+
(0.9, 3.0, None, 3),
630630
(0.8, 2.9, 5, np.full(5, 2)),
631631
(0.2, np.arange(1, 5) * 5, None, np.arange(1, 5)),
632632
(0.2, np.arange(1, 5) * 5, (2, 4), np.full((2, 4), np.arange(1, 5))),
@@ -1335,7 +1335,13 @@ def test_multinomial_moment(p, n, size, expected):
13351335
[
13361336
(0.2, 10, 3, None, 2),
13371337
(0.2, 10, 4, 5, np.full(5, 2)),
1338-
(0.4, np.arange(1, 5), np.arange(2, 6), None, np.array([0, 0, 1, 1])),
1338+
(
1339+
0.4,
1340+
np.arange(1, 5),
1341+
np.arange(2, 6),
1342+
None,
1343+
np.array([0, 1, 1, 2] if aesara.config.floatX == "float64" else [0, 0, 1, 1]),
1344+
),
13391345
(
13401346
np.linspace(0.2, 0.6, 3),
13411347
np.arange(1, 10, 4),

pymc/tests/test_distributions_random.py

Lines changed: 6 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,112 +1571,14 @@ def constant_rng_fn(self, size, c):
15711571
"check_pymc_params_match_rv_op",
15721572
"check_pymc_draws_match_reference",
15731573
"check_rv_size",
1574+
"check_dtype",
15741575
]
15751576

1576-
1577-
class TestZeroInflatedPoisson(BaseTestDistributionRandom):
1578-
def zero_inflated_poisson_rng_fn(self, size, psi, theta, poisson_rng_fct, random_rng_fct):
1579-
return poisson_rng_fct(theta, size=size) * (random_rng_fct(size=size) < psi)
1580-
1581-
def seeded_zero_inflated_poisson_rng_fn(self):
1582-
poisson_rng_fct = functools.partial(
1583-
getattr(np.random.RandomState, "poisson"), self.get_random_state()
1584-
)
1585-
1586-
random_rng_fct = functools.partial(
1587-
getattr(np.random.RandomState, "random"), self.get_random_state()
1588-
)
1589-
1590-
return functools.partial(
1591-
self.zero_inflated_poisson_rng_fn,
1592-
poisson_rng_fct=poisson_rng_fct,
1593-
random_rng_fct=random_rng_fct,
1594-
)
1595-
1596-
pymc_dist = pm.ZeroInflatedPoisson
1597-
pymc_dist_params = {"psi": 0.9, "theta": 4.0}
1598-
expected_rv_op_params = {"psi": 0.9, "theta": 4.0}
1599-
reference_dist_params = {"psi": 0.9, "theta": 4.0}
1600-
reference_dist = seeded_zero_inflated_poisson_rng_fn
1601-
checks_to_run = [
1602-
"check_pymc_params_match_rv_op",
1603-
"check_pymc_draws_match_reference",
1604-
"check_rv_size",
1605-
]
1606-
1607-
1608-
class TestZeroInflatedBinomial(BaseTestDistributionRandom):
1609-
def zero_inflated_binomial_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
1610-
return binomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)
1611-
1612-
def seeded_zero_inflated_binomial_rng_fn(self):
1613-
binomial_rng_fct = functools.partial(
1614-
getattr(np.random.RandomState, "binomial"), self.get_random_state()
1615-
)
1616-
1617-
random_rng_fct = functools.partial(
1618-
getattr(np.random.RandomState, "random"), self.get_random_state()
1619-
)
1620-
1621-
return functools.partial(
1622-
self.zero_inflated_binomial_rng_fn,
1623-
binomial_rng_fct=binomial_rng_fct,
1624-
random_rng_fct=random_rng_fct,
1625-
)
1626-
1627-
pymc_dist = pm.ZeroInflatedBinomial
1628-
pymc_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1629-
expected_rv_op_params = {"psi": 0.9, "n": 12, "p": 0.7}
1630-
reference_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1631-
reference_dist = seeded_zero_inflated_binomial_rng_fn
1632-
checks_to_run = [
1633-
"check_pymc_params_match_rv_op",
1634-
"check_pymc_draws_match_reference",
1635-
"check_rv_size",
1636-
]
1637-
1638-
1639-
class TestZeroInflatedNegativeBinomialMuSigma(BaseTestDistributionRandom):
1640-
def zero_inflated_negbinomial_rng_fn(
1641-
self, size, psi, n, p, negbinomial_rng_fct, random_rng_fct
1642-
):
1643-
return negbinomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)
1644-
1645-
def seeded_zero_inflated_negbinomial_rng_fn(self):
1646-
negbinomial_rng_fct = functools.partial(
1647-
getattr(np.random.RandomState, "negative_binomial"), self.get_random_state()
1648-
)
1649-
1650-
random_rng_fct = functools.partial(
1651-
getattr(np.random.RandomState, "random"), self.get_random_state()
1652-
)
1653-
1654-
return functools.partial(
1655-
self.zero_inflated_negbinomial_rng_fn,
1656-
negbinomial_rng_fct=negbinomial_rng_fct,
1657-
random_rng_fct=random_rng_fct,
1658-
)
1659-
1660-
n, p = pm.NegativeBinomial.get_n_p(mu=3, alpha=5)
1661-
1662-
pymc_dist = pm.ZeroInflatedNegativeBinomial
1663-
pymc_dist_params = {"psi": 0.9, "mu": 3, "alpha": 5}
1664-
expected_rv_op_params = {"psi": 0.9, "n": n, "p": p}
1665-
reference_dist_params = {"psi": 0.9, "n": n, "p": p}
1666-
reference_dist = seeded_zero_inflated_negbinomial_rng_fn
1667-
checks_to_run = [
1668-
"check_pymc_params_match_rv_op",
1669-
"check_pymc_draws_match_reference",
1670-
"check_rv_size",
1671-
]
1672-
1673-
1674-
class TestZeroInflatedNegativeBinomial(BaseTestDistributionRandom):
1675-
pymc_dist = pm.ZeroInflatedNegativeBinomial
1676-
pymc_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1677-
expected_rv_op_params = {"psi": 0.9, "n": 12, "p": 0.7}
1678-
reference_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
1679-
checks_to_run = ["check_pymc_params_match_rv_op"]
1577+
def check_dtype(self):
1578+
assert pm.Constant.dist(2**4).dtype == "int8"
1579+
assert pm.Constant.dist(2**16).dtype == "int32"
1580+
assert pm.Constant.dist(2**32).dtype == "int64"
1581+
assert pm.Constant.dist(2.0).dtype == aesara.config.floatX
16801582

16811583

16821584
class TestOrderedLogistic(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)