Skip to content

Commit d7cedd7

Browse files
committed
add kwarg to allow for custom sample_stats
1 parent 0960323 commit d7cedd7

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

pymc/testing.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import numpy as np
2121
import pytensor
2222
import pytensor.tensor as pt
23+
import xarray as xr
2324

2425
from arviz import InferenceData
2526
from numpy import random as nr
2627
from numpy import testing as npt
28+
from numpy.typing import NDArray
2729
from pytensor.compile import SharedVariable
2830
from pytensor.compile.mode import Mode
2931
from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs
@@ -976,7 +978,14 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
976978
raise AssertionError(f"RV found in graph: {rvs}")
977979

978980

979-
def mock_sample(draws: int = 10, **kwargs):
981+
SampleStatsCreator = Callable[[tuple[str, ...]], NDArray]
982+
983+
984+
def mock_sample(
985+
draws: int = 10,
986+
sample_stats: dict[str, SampleStatsCreator] | None = None,
987+
**kwargs,
988+
) -> InferenceData:
980989
"""Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.
981990
982991
Useful for testing models that use pm.sample without running MCMC sampling.
@@ -1006,6 +1015,36 @@ def mock_pymc_sample():
10061015
10071016
pm.sample = original_sample
10081017
1018+
By default, the sample_stats group is not created. Pass a dictionary of functions
1019+
that create sample statistics, where the keys are the names of the statistics
1020+
and the values are functions that take a size tuple and return an array of that size.
1021+
1022+
.. code-block:: python
1023+
1024+
from functools import partial
1025+
1026+
import numpy as np
1027+
import numpy.typing as npt
1028+
1029+
from pymc.testing import mock_sample
1030+
1031+
1032+
def mock_diverging(size: tuple[str, ...]) -> npt.NDArray:
1033+
return np.zeros(size)
1034+
1035+
1036+
def mock_tree_depth(size: tuple[str, ...]) -> npt.NDArray:
1037+
return np.random.choice(range(2, 10), size=size)
1038+
1039+
1040+
mock_sample_with_stats = partial(
1041+
mock_sample,
1042+
sample_stats={
1043+
"diverging": mock_diverging,
1044+
"tree_depth": mock_tree_depth,
1045+
},
1046+
)
1047+
10091048
"""
10101049
random_seed = kwargs.get("random_seed", None)
10111050
model = kwargs.get("model", None)
@@ -1028,6 +1067,16 @@ def mock_pymc_sample():
10281067
del idata["prior"]
10291068
if "prior_predictive" in idata:
10301069
del idata["prior_predictive"]
1070+
1071+
if sample_stats is not None:
1072+
sizes = idata["posterior"].sizes
1073+
size = (sizes["chain"], sizes["draw"])
1074+
sample_stats_ds = xr.Dataset(
1075+
{name: (("chain", "draw"), creator(size)) for name, creator in sample_stats.items()},
1076+
coords=idata["posterior"].coords,
1077+
)
1078+
idata.add_groups(sample_stats=sample_stats_ds)
1079+
10311080
return idata
10321081

10331082

0 commit comments

Comments
 (0)