Skip to content

Commit 30dedf2

Browse files
committed
push up pymc-marketing mock
1 parent d34ed95 commit 30dedf2

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

pymc/testing.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from collections.abc import Callable, Sequence
1919
from typing import Any
2020

21+
from arviz import InferenceData
22+
from xarray import DataArray
2123
import numpy as np
2224
import pytensor
2325
import pytensor.tensor as pt
@@ -982,3 +984,53 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
982984
rvs = rvs_in_graph(vars)
983985
if rvs:
984986
raise AssertionError(f"RV found in graph: {rvs}")
987+
988+
989+
def mock_sample(*args, **kwargs):
990+
"""Mock the pm.sample function by returning prior predictive samples as posterior.
991+
992+
Useful for testing models that use pm.sample without running MCMC sampling.
993+
994+
Examples
995+
--------
996+
Using mock_sample with pytest
997+
998+
.. code-block:: python
999+
1000+
import pytest
1001+
1002+
import pymc as pm
1003+
from pymc.testing import mock_sample
1004+
1005+
1006+
@pytest.fixture(scope="module")
1007+
def mock_pymc_sample():
1008+
original_sample = pm.sample
1009+
pm.sample = mock_sample
1010+
1011+
yield
1012+
1013+
pm.sample = original_sample
1014+
1015+
"""
1016+
random_seed = kwargs.get("random_seed", None)
1017+
model = kwargs.get("model", None)
1018+
draws = kwargs.get("draws", 10)
1019+
n_chains = kwargs.get("chains", 1)
1020+
idata: InferenceData = pm.sample_prior_predictive(
1021+
model=model,
1022+
random_seed=random_seed,
1023+
draws=draws,
1024+
)
1025+
1026+
expanded_chains = DataArray(
1027+
np.ones(n_chains),
1028+
coords={"chain": np.arange(n_chains)},
1029+
)
1030+
idata.add_groups(
1031+
posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...)
1032+
)
1033+
del idata.prior
1034+
if "prior_predictive" in idata:
1035+
del idata.prior_predictive
1036+
return idata

0 commit comments

Comments
 (0)