@@ -1571,112 +1571,14 @@ def constant_rng_fn(self, size, c):
1571
1571
"check_pymc_params_match_rv_op" ,
1572
1572
"check_pymc_draws_match_reference" ,
1573
1573
"check_rv_size" ,
1574
+ "check_dtype" ,
1574
1575
]
1575
1576
1576
-
1577
- class TestZeroInflatedPoisson (BaseTestDistributionRandom ):
1578
- def zero_inflated_poisson_rng_fn (self , size , psi , theta , poisson_rng_fct , random_rng_fct ):
1579
- return poisson_rng_fct (theta , size = size ) * (random_rng_fct (size = size ) < psi )
1580
-
1581
- def seeded_zero_inflated_poisson_rng_fn (self ):
1582
- poisson_rng_fct = functools .partial (
1583
- getattr (np .random .RandomState , "poisson" ), self .get_random_state ()
1584
- )
1585
-
1586
- random_rng_fct = functools .partial (
1587
- getattr (np .random .RandomState , "random" ), self .get_random_state ()
1588
- )
1589
-
1590
- return functools .partial (
1591
- self .zero_inflated_poisson_rng_fn ,
1592
- poisson_rng_fct = poisson_rng_fct ,
1593
- random_rng_fct = random_rng_fct ,
1594
- )
1595
-
1596
- pymc_dist = pm .ZeroInflatedPoisson
1597
- pymc_dist_params = {"psi" : 0.9 , "theta" : 4.0 }
1598
- expected_rv_op_params = {"psi" : 0.9 , "theta" : 4.0 }
1599
- reference_dist_params = {"psi" : 0.9 , "theta" : 4.0 }
1600
- reference_dist = seeded_zero_inflated_poisson_rng_fn
1601
- checks_to_run = [
1602
- "check_pymc_params_match_rv_op" ,
1603
- "check_pymc_draws_match_reference" ,
1604
- "check_rv_size" ,
1605
- ]
1606
-
1607
-
1608
- class TestZeroInflatedBinomial (BaseTestDistributionRandom ):
1609
- def zero_inflated_binomial_rng_fn (self , size , psi , n , p , binomial_rng_fct , random_rng_fct ):
1610
- return binomial_rng_fct (n , p , size = size ) * (random_rng_fct (size = size ) < psi )
1611
-
1612
- def seeded_zero_inflated_binomial_rng_fn (self ):
1613
- binomial_rng_fct = functools .partial (
1614
- getattr (np .random .RandomState , "binomial" ), self .get_random_state ()
1615
- )
1616
-
1617
- random_rng_fct = functools .partial (
1618
- getattr (np .random .RandomState , "random" ), self .get_random_state ()
1619
- )
1620
-
1621
- return functools .partial (
1622
- self .zero_inflated_binomial_rng_fn ,
1623
- binomial_rng_fct = binomial_rng_fct ,
1624
- random_rng_fct = random_rng_fct ,
1625
- )
1626
-
1627
- pymc_dist = pm .ZeroInflatedBinomial
1628
- pymc_dist_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1629
- expected_rv_op_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1630
- reference_dist_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1631
- reference_dist = seeded_zero_inflated_binomial_rng_fn
1632
- checks_to_run = [
1633
- "check_pymc_params_match_rv_op" ,
1634
- "check_pymc_draws_match_reference" ,
1635
- "check_rv_size" ,
1636
- ]
1637
-
1638
-
1639
- class TestZeroInflatedNegativeBinomialMuSigma (BaseTestDistributionRandom ):
1640
- def zero_inflated_negbinomial_rng_fn (
1641
- self , size , psi , n , p , negbinomial_rng_fct , random_rng_fct
1642
- ):
1643
- return negbinomial_rng_fct (n , p , size = size ) * (random_rng_fct (size = size ) < psi )
1644
-
1645
- def seeded_zero_inflated_negbinomial_rng_fn (self ):
1646
- negbinomial_rng_fct = functools .partial (
1647
- getattr (np .random .RandomState , "negative_binomial" ), self .get_random_state ()
1648
- )
1649
-
1650
- random_rng_fct = functools .partial (
1651
- getattr (np .random .RandomState , "random" ), self .get_random_state ()
1652
- )
1653
-
1654
- return functools .partial (
1655
- self .zero_inflated_negbinomial_rng_fn ,
1656
- negbinomial_rng_fct = negbinomial_rng_fct ,
1657
- random_rng_fct = random_rng_fct ,
1658
- )
1659
-
1660
- n , p = pm .NegativeBinomial .get_n_p (mu = 3 , alpha = 5 )
1661
-
1662
- pymc_dist = pm .ZeroInflatedNegativeBinomial
1663
- pymc_dist_params = {"psi" : 0.9 , "mu" : 3 , "alpha" : 5 }
1664
- expected_rv_op_params = {"psi" : 0.9 , "n" : n , "p" : p }
1665
- reference_dist_params = {"psi" : 0.9 , "n" : n , "p" : p }
1666
- reference_dist = seeded_zero_inflated_negbinomial_rng_fn
1667
- checks_to_run = [
1668
- "check_pymc_params_match_rv_op" ,
1669
- "check_pymc_draws_match_reference" ,
1670
- "check_rv_size" ,
1671
- ]
1672
-
1673
-
1674
- class TestZeroInflatedNegativeBinomial (BaseTestDistributionRandom ):
1675
- pymc_dist = pm .ZeroInflatedNegativeBinomial
1676
- pymc_dist_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1677
- expected_rv_op_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1678
- reference_dist_params = {"psi" : 0.9 , "n" : 12 , "p" : 0.7 }
1679
- checks_to_run = ["check_pymc_params_match_rv_op" ]
1577
+ def check_dtype (self ):
1578
+ assert pm .Constant .dist (2 ** 4 ).dtype == "int8"
1579
+ assert pm .Constant .dist (2 ** 16 ).dtype == "int32"
1580
+ assert pm .Constant .dist (2 ** 32 ).dtype == "int64"
1581
+ assert pm .Constant .dist (2.0 ).dtype == aesara .config .floatX
1680
1582
1681
1583
1682
1584
class TestOrderedLogistic (BaseTestDistributionRandom ):
0 commit comments