2020import numpy as np
2121import pytensor
2222import pytensor .tensor as pt
23+ import xarray as xr
2324
2425from arviz import InferenceData
2526from numpy import random as nr
2627from numpy import testing as npt
28+ from numpy .typing import NDArray
2729from pytensor .compile import SharedVariable
2830from pytensor .compile .mode import Mode
2931from pytensor .graph .basic import Constant , Variable , equal_computations , graph_inputs
@@ -976,7 +978,14 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
976978 raise AssertionError (f"RV found in graph: { rvs } " )
977979
978980
979- def mock_sample (draws : int = 10 , ** kwargs ):
# Signature of a sample-statistic factory accepted by ``mock_sample``: it is
# called with the ``(chain, draw)`` sizes (integers) and returns an array of
# that shape.
SampleStatsCreator = Callable[[tuple[int, ...]], NDArray]
982+
983+
984+ def mock_sample (
985+ draws : int = 10 ,
986+ sample_stats : dict [str , SampleStatsCreator ] | None = None ,
987+ ** kwargs ,
988+ ) -> InferenceData :
980989 """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.
981990
982991 Useful for testing models that use pm.sample without running MCMC sampling.
@@ -1006,6 +1015,36 @@ def mock_pymc_sample():
10061015
10071016 pm.sample = original_sample
10081017
1018+ By default, the sample_stats group is not created. Pass a dictionary of functions
1019+ that create sample statistics, where the keys are the names of the statistics
1020+ and the values are functions that take the ``(chain, draw)`` sizes as a tuple of integers and return an array of that shape.
1021+
1022+ .. code-block:: python
1023+
1024+ from functools import partial
1025+
1026+ import numpy as np
1027+ import numpy.typing as npt
1028+
1029+ from pymc.testing import mock_sample
1030+
1031+
1032+ def mock_diverging(size: tuple[int, ...]) -> npt.NDArray:
1033+ return np.zeros(size)
1034+
1035+
1036+ def mock_tree_depth(size: tuple[int, ...]) -> npt.NDArray:
1037+ return np.random.choice(range(2, 10), size=size)
1038+
1039+
1040+ mock_sample_with_stats = partial(
1041+ mock_sample,
1042+ sample_stats={
1043+ "diverging": mock_diverging,
1044+ "tree_depth": mock_tree_depth,
1045+ },
1046+ )
1047+
10091048 """
10101049 random_seed = kwargs .get ("random_seed" , None )
10111050 model = kwargs .get ("model" , None )
@@ -1028,6 +1067,16 @@ def mock_pymc_sample():
10281067 del idata ["prior" ]
10291068 if "prior_predictive" in idata :
10301069 del idata ["prior_predictive" ]
1070+
1071+ if sample_stats is not None :
1072+ sizes = idata ["posterior" ].sizes
1073+ size = (sizes ["chain" ], sizes ["draw" ])
1074+ sample_stats_ds = xr .Dataset (
1075+ {name : (("chain" , "draw" ), creator (size )) for name , creator in sample_stats .items ()},
1076+ coords = idata ["posterior" ].coords ,
1077+ )
1078+ idata .add_groups (sample_stats = sample_stats_ds )
1079+
10311080 return idata
10321081
10331082
0 commit comments