Skip to content

Commit 06754f7

Browse files
committed
use expand_dims method
1 parent ebe8cd2 commit 06754f7

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pymc/testing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from pytensor.tensor.random.op import RandomVariable
3333
from scipy import special as sp
3434
from scipy import stats as st
35-
from xarray import DataArray
3635

3736
import pymc as pm
3837

@@ -1023,12 +1022,13 @@ def mock_pymc_sample():
10231022
draws=draws,
10241023
)
10251024

1026-
expanded_chains = DataArray(
1027-
np.ones(n_chains),
1028-
coords={"chain": np.arange(n_chains)},
1029-
)
10301025
idata.add_groups(
1031-
posterior=(idata["prior"].mean("chain") * expanded_chains).transpose("chain", "draw", ...),
1026+
posterior=(
1027+
idata["prior"]
1028+
.isel(chain=0)
1029+
.expand_dims({"chain": range(n_chains)})
1030+
.transpose("chain", "draw", ...)
1031+
)
10321032
)
10331033
del idata["prior"]
10341034
if "prior_predictive" in idata:

tests/test_testing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ def test_domain(values, edges, expectation):
3838

3939

4040
@pytest.mark.parametrize(
41-
"args, kwargs, expected_draws",
41+
"args, kwargs, expected_size",
4242
[
43-
pytest.param((), {}, 10, id="default"),
44-
pytest.param((100,), {}, 100, id="positional-draws"),
45-
pytest.param((), {"draws": 100}, 100, id="keyword-draws"),
43+
pytest.param((), {}, (1, 10), id="default"),
44+
pytest.param((100,), {}, (1, 100), id="positional-draws"),
45+
pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
46+
pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
4647
],
4748
)
48-
def test_mock_sample(args, kwargs, expected_draws) -> None:
49+
def test_mock_sample(args, kwargs, expected_size) -> None:
50+
expected_chains, expected_draws = expected_size
4951
_, model, _ = simple_normal(bounded_prior=True)
5052

5153
with model:
@@ -57,7 +59,7 @@ def test_mock_sample(args, kwargs, expected_draws) -> None:
5759
assert "posterior_predictive" not in idata
5860
assert "sample_stats" not in idata
5961

60-
assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws}
62+
assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}
6163

6264

6365
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)

0 commit comments

Comments
 (0)