|
18 | 18 | from collections.abc import Callable, Sequence |
19 | 19 | from typing import Any |
20 | 20 |
|
| 21 | +from arviz import InferenceData |
| 22 | +from xarray import DataArray |
21 | 23 | import numpy as np |
22 | 24 | import pytensor |
23 | 25 | import pytensor.tensor as pt |
@@ -982,3 +984,53 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: |
982 | 984 | rvs = rvs_in_graph(vars) |
983 | 985 | if rvs: |
984 | 986 | raise AssertionError(f"RV found in graph: {rvs}") |
| 987 | + |
| 988 | + |
| 989 | +def mock_sample(*args, **kwargs): |
| 990 | + """Mock the pm.sample function by returning prior predictive samples as posterior. |
| 991 | +
|
| 992 | + Useful for testing models that use pm.sample without running MCMC sampling. |
| 993 | +
|
| 994 | + Examples |
| 995 | + -------- |
| 996 | + Using mock_sample with pytest |
| 997 | +
|
| 998 | + .. code-block:: python |
| 999 | +
|
| 1000 | + import pytest |
| 1001 | +
|
| 1002 | + import pymc as pm |
| 1003 | + from pymc.testing import mock_sample |
| 1004 | +
|
| 1005 | +
|
| 1006 | + @pytest.fixture(scope="module") |
| 1007 | + def mock_pymc_sample(): |
| 1008 | + original_sample = pm.sample |
| 1009 | + pm.sample = mock_sample |
| 1010 | +
|
| 1011 | + yield |
| 1012 | +
|
| 1013 | + pm.sample = original_sample |
| 1014 | +
|
| 1015 | + """ |
| 1016 | + random_seed = kwargs.get("random_seed", None) |
| 1017 | + model = kwargs.get("model", None) |
| 1018 | + draws = kwargs.get("draws", 10) |
| 1019 | + n_chains = kwargs.get("chains", 1) |
| 1020 | + idata: InferenceData = pm.sample_prior_predictive( |
| 1021 | + model=model, |
| 1022 | + random_seed=random_seed, |
| 1023 | + draws=draws, |
| 1024 | + ) |
| 1025 | + |
| 1026 | + expanded_chains = DataArray( |
| 1027 | + np.ones(n_chains), |
| 1028 | + coords={"chain": np.arange(n_chains)}, |
| 1029 | + ) |
| 1030 | + idata.add_groups( |
| 1031 | + posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...) |
| 1032 | + ) |
| 1033 | + del idata.prior |
| 1034 | + if "prior_predictive" in idata: |
| 1035 | + del idata.prior_predictive |
| 1036 | + return idata |
0 commit comments