Skip to content

Commit 3012b17

Browse files
committed
Add loop_over_posterior
1 parent 2bac5da commit 3012b17

File tree

3 files changed

+241
-1
lines changed

3 files changed

+241
-1
lines changed

docs/source/api/samplers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ This submodule contains functions for MCMC and forward sampling.
1515
draw
1616
compute_deterministics
1717
vectorize_over_posterior
18+
loop_over_posterior
1819
init_nuts
1920
sampling.jax.sample_blackjax_nuts
2021
sampling.jax.sample_numpyro_nuts

pymc/sampling/forward.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
walk,
4343
)
4444
from pytensor.graph.fg import FunctionGraph
45+
from pytensor.scan.basic import scan
4546
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
4647
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
4748
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -57,7 +58,12 @@
5758
from pymc.distributions.shape_utils import change_dist_size
5859
from pymc.model import Model, modelcontext
5960
from pymc.progress_bar import CustomProgress, default_progress_theme
60-
from pymc.pytensorf import compile, rvs_in_graph
61+
from pymc.pytensorf import (
62+
clone_while_sharing_some_variables,
63+
collect_default_updates,
64+
compile,
65+
rvs_in_graph,
66+
)
6167
from pymc.util import (
6268
RandomState,
6369
_get_seeds_per_chain,
@@ -68,6 +74,7 @@
6874
__all__ = (
6975
"compile_forward_sampling_function",
7076
"draw",
77+
"loop_over_posterior",
7178
"sample_posterior_predictive",
7279
"sample_prior_predictive",
7380
"vectorize_over_posterior",
@@ -1083,3 +1090,122 @@ def vectorize_over_posterior(
10831090
f"The following random variables found in the extracted graph: {remaining_rvs}"
10841091
)
10851092
return vectorized_outputs
1093+
1094+
1095+
def loop_over_posterior(
    outputs: list[Variable],
    posterior: xr.Dataset,
    input_rvs: list[Variable],
    input_tensors: Sequence[Variable] = (),
    allow_rvs_in_graph: bool = True,
    sample_dims: tuple[str, ...] = ("chain", "draw"),
) -> tuple[list[Variable], dict[Variable, Variable]]:
    """Loop over posterior samples of subset of input rvs.

    This function creates a new graph for the supplied outputs in which the
    required subset of input rvs are replaced, one posterior sample at a time,
    by their posterior values inside a ``scan``. The dimensions listed in
    ``sample_dims`` are flattened into a single axis that the scan iterates
    over, and the scan outputs are reshaped back so that ``sample_dims`` are
    the leading dimensions of every returned variable.

    Parameters
    ----------
    outputs : list[Variable]
        The list of variables to evaluate for every posterior sample.
    posterior : xr.Dataset
        The posterior samples to use as replacements for the `input_rvs`.
        Every replaced random variable is looked up by its ``name``.
    input_rvs : list[Variable]
        The list of random variables to replace with their posterior samples.
    input_tensors : Sequence[Variable]
        The list of tensors to keep as is. They are handed to ``scan`` as
        non-sequences.
    allow_rvs_in_graph : bool
        Whether to allow random variables to be present in the graph. If False,
        an error will be raised if any random variables are found in the
        looped graph. If True, remaining random variables stay in the graph and
        the updates collected for them are returned alongside the outputs.
    sample_dims : tuple[str, ...]
        The dimensions of the posterior samples to use for looping the `input_rvs`.

    Returns
    -------
    looped_outputs : list[Variable]
        The looped variables, reshaped to match the original shape of the outputs, but
        adding the sample_dims to the left.
    updates : dict[Variable, Variable]
        Dictionary of updates needed to compile the pytensor function to produce the
        outputs.

    Raises
    ------
    RuntimeError
        If random variables are found in the graph and `allow_rvs_in_graph` is False
    ValueError
        If the supplied output tensors do not depend on the requested input tensors
    """
    unused_input_tensors = set(input_tensors) - set(ancestors(outputs))
    if unused_input_tensors:
        raise ValueError(  # pragma: no cover
            "The supplied output tensors do not depend on the following requested "
            f"input tensors: {unused_input_tensors}"
        )

    # Build the lookup set once instead of once per visited ancestor.
    input_rvs_set = set(input_rvs)
    outputs_ancestors = ancestors(outputs, blockers=input_rvs)
    rvs_from_posterior: list[TensorVariable] = [
        cast(TensorVariable, rv) for rv in outputs_ancestors if rv in input_rvs_set
    ]
    # NOTE(review): `outputs_ancestors` has already been iterated above. If
    # `ancestors` returns a single-pass generator, the membership test below
    # runs against an exhausted iterator and this list is always empty.
    # TODO confirm the intended semantics before materialising it.
    independent_rvs = [
        rv
        for rv in rvs_in_graph(outputs)
        if rv in outputs_ancestors and rv not in rvs_from_posterior
    ]
    n_non_sequences = len(input_tensors) + len(independent_rvs)

    def step(*args):
        # `args` holds one slice of every sequence followed by the non-sequences.
        split = len(args) - n_non_sequences
        input_values = args[:split]
        non_sequences = args[split:]

        # Compute output sample value for input sample values
        replace = dict(zip(rvs_from_posterior, input_values, strict=True))
        samples = clone_while_sharing_some_variables(
            outputs, replace=replace, kept_variables=non_sequences
        )

        # Collect updates if there are RV Ops in the graph
        updates = collect_default_updates(outputs=samples, inputs=input_values)
        return (*samples,), updates

    n_sample_dims = len(sample_dims)
    batch_shape = tuple(len(posterior.coords[dim]) for dim in sample_dims)
    # `np.prod` of an empty tuple is a float; reshape needs a plain integer.
    nsamples = int(np.prod(batch_shape))

    sequences = []
    for rv in rvs_from_posterior:
        # Put the sample dimensions first, in the requested order, so that
        # flattening them lines up with `batch_shape`.
        values = posterior[rv.name].transpose(*sample_dims, ...).data
        sequences.append(
            pt.constant(
                # Skip exactly as many leading axes as there are sample dims
                # (this used to be hard-coded to two).
                np.reshape(values, (nsamples, *values.shape[n_sample_dims:])),
                name=rv.name,
                dtype=rv.dtype,
            )
        )

    scan_out, updates = scan(
        fn=step,
        sequences=sequences,
        non_sequences=[*input_tensors, *independent_rvs],
        n_steps=nsamples,
    )
    if len(outputs) == 1:
        # With a single output the scan result is not wrapped in a list.
        scan_out = [scan_out]  # pragma: no cover

    looped: list[Variable] = []
    for out in scan_out:
        # Prefer statically known core dimensions and fall back to the
        # symbolic shape where the static one is unknown.
        core_shape = tuple(
            static if static is not None else dynamic
            for static, dynamic in zip(out.type.shape[1:], out.shape[1:])
        )
        # Unflatten the looped axis back into the requested sample dimensions.
        looped.append(pt.reshape(out, (*batch_shape, *core_shape)))

    if not allow_rvs_in_graph:
        remaining_rvs = rvs_in_graph(looped)
        if remaining_rvs:
            raise RuntimeError(
                f"The following random variables found in the extracted graph: {remaining_rvs}"
            )
    return looped, updates

tests/sampling/test_forward.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
compile_forward_sampling_function,
4242
get_constant_coords,
4343
get_vars_in_point_list,
44+
loop_over_posterior,
4445
observed_dependent_deterministics,
4546
vectorize_over_posterior,
4647
)
@@ -1958,3 +1959,115 @@ def test_vectorize_over_posterior_matches_sample():
19581959
atol=0.6 / np.sqrt(10000),
19591960
)
19601961
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
1962+
1963+
1964+
def test_loop_over_posterior(
    variable_to_vectorize,
    input_rv_names,
    allow_rvs_in_graph,
    model_to_vectorize,
):
    """Check the structure of the graph built by ``loop_over_posterior``.

    Covers the error raised when random variables are not allowed but remain
    in the graph, and otherwise inspects the looped outputs and the inner
    graph of the op that produces them.
    """
    model, idata = model_to_vectorize
    call_kwargs = {
        "outputs": [model[name] for name in variable_to_vectorize],
        "posterior": idata.posterior,
        "input_rvs": [model[name] for name in input_rv_names],
        "input_tensors": [model["d"]],
        "allow_rvs_in_graph": allow_rvs_in_graph,
    }

    # Random variables are disallowed but would remain in the graph.
    expect_error = not allow_rvs_in_graph and (
        len(input_rv_names) == 0 or "z" in variable_to_vectorize
    )
    if expect_error:
        with pytest.raises(
            RuntimeError,
            match="The following random variables found in the extracted graph",
        ):
            loop_over_posterior(**call_kwargs)
        return

    vectorized, _ = loop_over_posterior(**call_kwargs)

    for looped_var, name in zip(vectorized, variable_to_vectorize):
        assert looped_var is not model[name]
    for looped_var in vectorized:
        assert looped_var.type.shape == (1, 100, 3)

    # Each output is a reshape whose first input is produced by the looping op.
    inner_graph_outputs = []
    for position, looped_var in enumerate(vectorized):
        inner_outputs = looped_var.owner.inputs[0].owner.op.inner_outputs
        assert variable_depends_on(inner_outputs[0], model["d"])
        inner_graph_outputs.append(inner_outputs[position])

    if len(vectorized) == 2:
        downstream_position = variable_to_vectorize.index("z_downstream")
        upstream_position = variable_to_vectorize.index("z")
        assert variable_depends_on(
            inner_graph_outputs[downstream_position],
            inner_graph_outputs[upstream_position],
        )

    if len(input_rv_names) == 0:
        # Nothing was replaced: the inner graph keeps the original RV shapes.
        original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
        expected_rv_shapes = {rv.type.shape for rv in original_rvs}
        inner_rvs = rvs_in_graph(inner_graph_outputs)
        assert {rv.type.shape for rv in inner_rvs} == expected_rv_shapes
        return

    for input_rv_name in input_rv_names:
        if input_rv_name == "x_parent":
            assert len(get_var_by_name(inner_graph_outputs, input_rv_name)) == 0
            continue
        # Every other replaced rv must show up as a constant holding its posterior.
        [vectorized_rv] = get_var_by_name(vectorized, input_rv_name)
        rv_posterior = idata.posterior[input_rv_name].data
        assert isinstance(vectorized_rv, TensorConstant)
        np.testing.assert_equal(
            np.reshape(vectorized_rv.value, rv_posterior.shape),
            rv_posterior,
            strict=True,
        )
2030+
2031+
2032+
def test_loop_over_posterior_matches_sample():
    """Compare ``loop_over_posterior`` draws with ``sample_posterior_predictive``.

    The posterior of ``x`` is set to 0 for the first chain and 100 for the
    second, so draws of ``obs`` must stay close to their chain's value.
    """
    rng = np.random.default_rng(1234)
    with pm.Model() as model:
        x = pm.Normal("x")
        sigma = 0.1
        obs = pm.Normal("obs", x, sigma, observed=rng.normal(size=10))
        det = pm.Deterministic("det", obs + 1)

    n_chains = 2
    n_draws = 100
    chain_offsets = 100 * np.arange(n_chains)[..., None]
    x_posterior = np.broadcast_to(chain_offsets, (n_chains, n_draws))
    with model:
        x_samples = xr.DataArray(
            x_posterior,
            dims=("chain", "draw"),
            coords={"chain": np.arange(n_chains), "draw": np.arange(n_draws)},
        )
        posterior = xr.Dataset({"x": x_samples})
    idata = InferenceData(posterior=posterior)

    with model:
        pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234)
        vectorized, updates = loop_over_posterior(
            outputs=[obs, det],
            posterior=posterior,
            input_rvs=[x],
            allow_rvs_in_graph=True,
        )

    looped_fn = compile(inputs=[], outputs=vectorized, random_seed=1234, updates=updates)
    [vect_obs, vect_det] = looped_fn()

    predictive = pp.posterior_predictive
    assert predictive["obs"].shape == vect_obs.shape
    assert predictive["det"].shape == vect_det.shape

    # The deterministic must be computed from the very same draws of `obs`.
    np.testing.assert_allclose(vect_obs + 1, vect_det)

    # Both sampling routes must agree on the mean over all posterior samples.
    np.testing.assert_allclose(
        predictive["obs"].mean(dim=("chain", "draw")),
        vect_obs.mean(axis=(0, 1)),
        atol=0.6 / np.sqrt(10000),
    )

    # Every draw of `obs` stays close to the posterior value of `x` it came from.
    assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)

0 commit comments

Comments
 (0)