Skip to content

Commit f7de7f4

Browse files
ricardoV94michaelosthege
authored andcommitted
Move stickbreaking tests
These tests were introduced between pre-existing Dirichlet/Multinomial/DirichletMultionmial tests that belong to conceptually related distributions
1 parent a3bab7d commit f7de7f4

File tree

2 files changed

+69
-69
lines changed

2 files changed

+69
-69
lines changed

pymc/tests/test_distributions.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,40 +2122,6 @@ def test_dirichlet_invalid(self):
21222122
valid_dist = Dirichlet.dist(a=[1, 1, 1])
21232123
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
21242124

2125-
@pytest.mark.parametrize(
2126-
"value,alpha,K,logp",
2127-
[
2128-
(np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
2129-
(np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
2130-
(np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
2131-
(np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092),
2132-
(
2133-
(np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
2134-
2.5,
2135-
3,
2136-
np.array([1.29317672, 1.50126157]),
2137-
),
2138-
],
2139-
)
2140-
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
2141-
with Model() as model:
2142-
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
2143-
pt = {"sbw": value}
2144-
assert_almost_equal(
2145-
pm.logp(sbw, value).eval(),
2146-
logp,
2147-
decimal=select_by_precision(float64=6, float32=2),
2148-
err_msg=str(pt),
2149-
)
2150-
2151-
def test_stickbreakingweights_invalid(self):
2152-
sbw = pm.StickBreakingWeights.dist(3.0, 3)
2153-
sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
2154-
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
2155-
assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
2156-
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
2157-
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
2158-
21592125
@pytest.mark.parametrize(
21602126
"a",
21612127
[
@@ -2318,6 +2284,40 @@ def test_dirichlet_multinomial_vectorized(self, n, a, size):
23182284
err_msg=f"vals={vals}",
23192285
)
23202286

2287+
@pytest.mark.parametrize(
2288+
"value,alpha,K,logp",
2289+
[
2290+
(np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
2291+
(np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
2292+
(np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
2293+
(np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092),
2294+
(
2295+
(np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
2296+
2.5,
2297+
3,
2298+
np.array([1.29317672, 1.50126157]),
2299+
),
2300+
],
2301+
)
2302+
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
2303+
with Model() as model:
2304+
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
2305+
pt = {"sbw": value}
2306+
assert_almost_equal(
2307+
pm.logp(sbw, value).eval(),
2308+
logp,
2309+
decimal=select_by_precision(float64=6, float32=2),
2310+
err_msg=str(pt),
2311+
)
2312+
2313+
def test_stickbreakingweights_invalid(self):
2314+
sbw = pm.StickBreakingWeights.dist(3.0, 3)
2315+
sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
2316+
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
2317+
assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
2318+
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
2319+
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
2320+
23212321
@aesara.config.change_flags(compute_test_value="raise")
23222322
def test_categorical_bounds(self):
23232323
with Model():

pymc/tests/test_distributions_random.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,41 +1309,6 @@ class TestDirichlet(BaseTestDistributionRandom):
13091309
]
13101310

13111311

1312-
class TestStickBreakingWeights(BaseTestDistributionRandom):
1313-
pymc_dist = pm.StickBreakingWeights
1314-
pymc_dist_params = {"alpha": 2.0, "K": 19}
1315-
expected_rv_op_params = {"alpha": 2.0, "K": 19}
1316-
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
1317-
sizes_expected = [
1318-
(20,),
1319-
(17, 20),
1320-
(
1321-
5,
1322-
20,
1323-
),
1324-
(11, 5, 20),
1325-
(3, 13, 5, 20),
1326-
]
1327-
checks_to_run = [
1328-
"check_pymc_params_match_rv_op",
1329-
"check_rv_size",
1330-
"check_basic_properties",
1331-
]
1332-
1333-
def check_basic_properties(self):
1334-
default_rng = aesara.shared(np.random.default_rng(1234))
1335-
draws = pm.StickBreakingWeights.dist(
1336-
alpha=3.5,
1337-
K=19,
1338-
size=(2, 3, 5),
1339-
rng=default_rng,
1340-
).eval()
1341-
1342-
assert np.allclose(draws.sum(-1), 1)
1343-
assert np.all(draws >= 0)
1344-
assert np.all(draws <= 1)
1345-
1346-
13471312
class TestMultinomial(BaseTestDistributionRandom):
13481313
pymc_dist = pm.Multinomial
13491314
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
@@ -1400,6 +1365,41 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom):
14001365
checks_to_run = ["check_rv_size"]
14011366

14021367

1368+
class TestStickBreakingWeights(BaseTestDistributionRandom):
1369+
pymc_dist = pm.StickBreakingWeights
1370+
pymc_dist_params = {"alpha": 2.0, "K": 19}
1371+
expected_rv_op_params = {"alpha": 2.0, "K": 19}
1372+
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
1373+
sizes_expected = [
1374+
(20,),
1375+
(17, 20),
1376+
(
1377+
5,
1378+
20,
1379+
),
1380+
(11, 5, 20),
1381+
(3, 13, 5, 20),
1382+
]
1383+
checks_to_run = [
1384+
"check_pymc_params_match_rv_op",
1385+
"check_rv_size",
1386+
"check_basic_properties",
1387+
]
1388+
1389+
def check_basic_properties(self):
1390+
default_rng = aesara.shared(np.random.default_rng(1234))
1391+
draws = pm.StickBreakingWeights.dist(
1392+
alpha=3.5,
1393+
K=19,
1394+
size=(2, 3, 5),
1395+
rng=default_rng,
1396+
).eval()
1397+
1398+
assert np.allclose(draws.sum(-1), 1)
1399+
assert np.all(draws >= 0)
1400+
assert np.all(draws <= 1)
1401+
1402+
14031403
class TestCategorical(BaseTestDistributionRandom):
14041404
pymc_dist = pm.Categorical
14051405
pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}

0 commit comments

Comments
 (0)