diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst
index 14a39c376..6ce53f418 100644
--- a/docs/source/api/samplers.rst
+++ b/docs/source/api/samplers.rst
@@ -15,6 +15,7 @@ This submodule contains functions for MCMC and forward sampling.
    draw
    compute_deterministics
    vectorize_over_posterior
+   loop_over_posterior
    init_nuts
    sampling.jax.sample_blackjax_nuts
    sampling.jax.sample_numpyro_nuts
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 3f90a69ff..bdb590f91 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -58,6 +58,7 @@ __all__ = [
     "CallableTensor",
+    "clone_while_sharing_some_variables",
     "compile",
     "cont_inputs",
     "convert_data",
@@ -1041,3 +1042,42 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
             "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
         )
     return rng
+
+
+def clone_while_sharing_some_variables(
+    outputs: list[Variable],
+    kept_variables: Sequence[Variable] = (),
+    replace: dict[Variable, Variable] | None = None,
+) -> list[Variable]:
+    """Clone graphs, applying replacements while preserving some original variables.
+
+    Parameters
+    ----------
+    outputs : list[Variable]
+        The list of variables to clone.
+    kept_variables : Sequence[Variable]
+        The variables to preserve in the cloned graph; these are not cloned.
+    replace : dict[Variable, Variable], optional
+        A dictionary of variables to replace in the cloned graph.
+        The keys are the variables to replace, and the values are the new variables
+        to use in their place.
+
+    Returns
+    -------
+    list[Variable]
+        The cloned graphs with the replacements applied.
+    """
+    replace_dict = replace or {}
+
+    memo = {rv: rv for rv in kept_variables}
+    clone_map = clone_get_equiv(
+        [],
+        outputs,
+        memo=memo,
+    )
+
+    replace_keys = [clone_map.get(key, key) for key in replace_dict]
+    replace_values = replace_vars_in_graphs(list(replace_dict.values()), clone_map)
+    fg = FunctionGraph(None, [clone_map[o] for o in outputs], clone=False)
+    fg.replace_all(list(zip(replace_keys, replace_values)), import_missing=True)
+    return fg.outputs
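Note for reviewers: a minimal sketch (not part of the diff) of what the new helper guarantees. The toy model and names are illustrative; the assertions mirror `test_clone_while_sharing_some_variables` below. Variables listed in `kept_variables` survive the clone as the same objects, while keys of `replace` are swapped out.

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pytensor.graph.basic import ancestors

from pymc.pytensorf import clone_while_sharing_some_variables

with pm.Model() as model:
    x = pm.Normal("x")
    d = pm.Data("d", np.array([1.0, 2.0, 3.0]))
    y = pm.Deterministic("y", x * d)

# Clone y's graph: keep x as-is, substitute constant zeros for the data variable d.
# Building the replacement from d's *value* keeps it independent of d itself.
[y_clone] = clone_while_sharing_some_variables(
    [y], kept_variables=[x], replace={d: pt.zeros_like(d.get_value())}
)
assert y_clone is not y                    # y itself was cloned
assert x in set(ancestors([y_clone]))      # x was preserved, not cloned
assert d not in set(ancestors([y_clone]))  # d was replaced
np.testing.assert_array_equal(y_clone.eval(), np.zeros(3))  # x * 0 == 0
```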
= ("chain", "draw"), +) -> tuple[list[Variable], dict[Variable, Variable]]: + """Loop over posterior samples of subset of input rvs. + + This function creates a new graph for the supplied outputs, where the required + subset of input rvs are replaced by their posterior samples (chain and draw + dimensions, or the dimensions provided in sample_dims are flattened). The other + input tensors are kept as is. + + Parameters + ---------- + outputs : list[Variable] + The list of variables to vectorize over the posterior samples. + posterior : xr.Dataset + The posterior samples to use as replacements for the `input_rvs`. + input_rvs : list[Variable] + The list of random variables to replace with their posterior samples. + input_tensors : Sequence[Variable] + The list of tensors to keep as is. + allow_rvs_in_graph : bool + Whether to allow random variables to be present in the graph. If False, + an error will be raised if any random variables are found in the graph. If + True, the remaining random variables will be resized to match the number of + draws from the posterior. + sample_dims : tuple[str, ...] + The dimensions of the posterior samples to use for looping the `input_rvs`. + + Returns + ------- + looped_outputs : list[Variable] + The looped variables, reshaped to match the original shape of the outputs, but + adding the sample_dims to the left. + updates : dict[Variable, Variable] + Dictionary of updates needed to compile the pytensor function to produce the + outputs. + + Raises + ------ + RuntimeError + If random variables are found in the graph and `allow_rvs_in_graph` is False + ValueError + If the supplied output tensors do not depend on the requested input tensors + """ + if not (set(input_tensors) <= set(ancestors(outputs))): + raise ValueError( # pragma: no cover + "The supplied output tensors do not depend on the following requested " + f"input tensors: {set(input_tensors) - set(ancestors(outputs))}" + ) + outputs_ancestors = ancestors(outputs, blockers=input_rvs) + rvs_from_posterior: list[TensorVariable] = [ + cast(TensorVariable, rv) for rv in outputs_ancestors if rv in set(input_rvs) + ] + independent_rvs = [ + rv + for rv in rvs_in_graph(outputs) + if rv in outputs_ancestors and rv not in rvs_from_posterior + ] + + def step(*args): + input_values = args[: len(args) - len(input_tensors) - len(independent_rvs)] + non_sequences = args[len(args) - len(input_tensors) - len(independent_rvs) :] + + # Compute output sample value for input sample values + replace = { + **dict(zip(rvs_from_posterior, input_values, strict=True)), + } + samples = clone_while_sharing_some_variables( + outputs, replace=replace, kept_variables=non_sequences + ) + + # Collect updates if there are RV Ops in the graph + updates = collect_default_updates(outputs=samples, inputs=input_values) + return (*samples,), updates + + sequences = [] + batch_shape = tuple([len(posterior.coords[dim]) for dim in sample_dims]) + nsamples = np.prod(batch_shape) + for rv in rvs_from_posterior: + values = posterior[rv.name].data + sequences.append( + pt.constant( + np.reshape(values, (nsamples, *values.shape[2:])), + name=rv.name, + dtype=rv.dtype, + ) + ) + scan_out, updates = scan( + fn=step, + sequences=sequences, + non_sequences=[*input_tensors, *independent_rvs], + n_steps=nsamples, + ) + if len(outputs) == 1: + scan_out = [scan_out] # pragma: no cover + + looped: list[Variable] = [] + for out in scan_out: + core_shape = tuple( + [ + static if static is not None else dynamic + for static, dynamic in zip(out.type.shape[1:], 
diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py
index 3dd30e14f..1cce326d3 100644
--- a/tests/sampling/test_forward.py
+++ b/tests/sampling/test_forward.py
@@ -41,6 +41,7 @@
     compile_forward_sampling_function,
     get_constant_coords,
     get_vars_in_point_list,
+    loop_over_posterior,
     observed_dependent_deterministics,
     vectorize_over_posterior,
 )
@@ -1958,3 +1959,115 @@
         atol=0.6 / np.sqrt(10000),
     )
     assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
+
+
+def test_loop_over_posterior(
+    variable_to_vectorize,
+    input_rv_names,
+    allow_rvs_in_graph,
+    model_to_vectorize,
+):
+    model, idata = model_to_vectorize
+
+    if not allow_rvs_in_graph and (len(input_rv_names) == 0 or "z" in variable_to_vectorize):
+        with pytest.raises(
+            RuntimeError,
+            match="The following random variables found in the extracted graph",
+        ):
+            loop_over_posterior(
+                outputs=[model[name] for name in variable_to_vectorize],
+                posterior=idata.posterior,
+                input_rvs=[model[name] for name in input_rv_names],
+                input_tensors=[model["d"]],
+                allow_rvs_in_graph=allow_rvs_in_graph,
+            )
+    else:
+        vectorized, _ = loop_over_posterior(
+            outputs=[model[name] for name in variable_to_vectorize],
+            posterior=idata.posterior,
+            input_rvs=[model[name] for name in input_rv_names],
+            input_tensors=[model["d"]],
+            allow_rvs_in_graph=allow_rvs_in_graph,
+        )
+        assert all(
+            vectorized_var is not model[name]
+            for vectorized_var, name in zip(vectorized, variable_to_vectorize)
+        )
+        assert all(vectorized_var.type.shape == (1, 100, 3) for vectorized_var in vectorized)
+        assert all(
+            variable_depends_on(
+                vectorized_var.owner.inputs[0].owner.op.inner_outputs[0], model["d"]
+            )
+            for vectorized_var in vectorized
+        )
+        inner_graph_outputs = [
+            vectorized_var.owner.inputs[0].owner.op.inner_outputs[i]
+            for i, vectorized_var in enumerate(vectorized)
+        ]
+        if len(vectorized) == 2:
+            assert variable_depends_on(
+                inner_graph_outputs[variable_to_vectorize.index("z_downstream")],
+                inner_graph_outputs[variable_to_vectorize.index("z")],
+            )
+        if len(input_rv_names) > 0:
+            for input_rv_name in input_rv_names:
+                if input_rv_name == "x_parent":
+                    assert len(get_var_by_name(inner_graph_outputs, input_rv_name)) == 0
+                else:
+                    [vectorized_rv] = get_var_by_name(vectorized, input_rv_name)
+                    rv_posterior = idata.posterior[input_rv_name].data
+                    assert isinstance(vectorized_rv, TensorConstant)
+                    np.testing.assert_equal(
+                        np.reshape(vectorized_rv.value, rv_posterior.shape),
+                        rv_posterior,
+                        strict=True,
+                    )
+        else:
+            original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
+            expected_rv_shapes = {rv.type.shape for rv in original_rvs}
+            rvs = rvs_in_graph(inner_graph_outputs)
+            assert {rv.type.shape for rv in rvs} == expected_rv_shapes
+
+
+def test_loop_over_posterior_matches_sample():
+    rng = np.random.default_rng(1234)
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        sigma = 0.1
+        obs = pm.Normal("obs", x, sigma, observed=rng.normal(size=10))
+        det = pm.Deterministic("det", obs + 1)
+
+    chains = 2
+    draws = 100
+    x_posterior = np.broadcast_to(100 * np.arange(chains)[..., None], (chains, draws))
+    with model:
+        posterior = xr.Dataset(
+            {
+                "x": xr.DataArray(
+                    x_posterior,
+                    dims=("chain", "draw"),
+                    coords={"chain": np.arange(chains), "draw": np.arange(draws)},
+                )
+            }
+        )
+    idata = InferenceData(posterior=posterior)
+    with model:
+        pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234)
+    vectorized, updates = loop_over_posterior(
+        outputs=[obs, det],
+        posterior=posterior,
+        input_rvs=[x],
+        allow_rvs_in_graph=True,
+    )
+    [vect_obs, vect_det] = compile(
+        inputs=[], outputs=vectorized, random_seed=1234, updates=updates
+    )()
+    assert pp.posterior_predictive["obs"].shape == vect_obs.shape
+    assert pp.posterior_predictive["det"].shape == vect_det.shape
+    np.testing.assert_allclose(vect_obs + 1, vect_det)
+    np.testing.assert_allclose(
+        pp.posterior_predictive["obs"].mean(dim=("chain", "draw")),
+        vect_obs.mean(axis=(0, 1)),
+        atol=0.6 / np.sqrt(10000),
+    )
+    assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
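Note for reviewers: the structural assertions above reach into the scan graph via `owner.inputs[0].owner.op.inner_outputs`. A small sketch (not part of the diff; the toy model is illustrative) of why that chain works: each returned output is a `Reshape` of a scan output, and the `Scan` Op exposes the per-iteration `step` graph through `inner_outputs`.

```python
import numpy as np
import pymc as pm
import xarray as xr

from pymc.sampling.forward import loop_over_posterior

with pm.Model():
    x = pm.Normal("x")
    det = pm.Deterministic("det", x + 1)

posterior = xr.Dataset(
    {
        "x": xr.DataArray(
            np.zeros((1, 5)),
            dims=("chain", "draw"),
            coords={"chain": np.arange(1), "draw": np.arange(5)},
        )
    }
)
[out], _ = loop_over_posterior([det], posterior, input_rvs=[x])

scan_var = out.owner.inputs[0]  # the raw scan output being reshaped
scan_op = scan_var.owner.op     # the Scan Op itself
print(scan_op.inner_outputs)    # symbolic per-iteration outputs (step's graph)
```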
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index d172c61a4..068ec0bc3 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -24,7 +24,7 @@
 from pytensor import scan, shared
 from pytensor.compile import UnusedInputError
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph.basic import Variable, equal_computations
+from pytensor.graph.basic import Variable, ancestors, equal_computations, get_var_by_name
 from pytensor.tensor.subtensor import AdvancedIncSubtensor
 
 import pymc as pm
@@ -36,6 +36,7 @@
 from pymc.logprob.utils import ParameterValueError
 from pymc.pytensorf import (
     PointFunc,
+    clone_while_sharing_some_variables,
     collect_default_updates,
     compile,
     constant_fold,
@@ -785,3 +786,24 @@ def test_pickle_point_func():
     np.testing.assert_allclose(
         point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]})
     )
+
+
+def test_clone_while_sharing_some_variables():
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        d = pm.Data("d", np.array([1, 2, 3]))
+        obs = pm.Data("obs", np.ones_like(d.get_value()))
+        y = pm.Deterministic("y", x * d)
+        z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs)
+
+    kept_variables = [*model.free_RVs, *model.data_vars]
+    d_replace = pt.zeros_like(d.get_value())
+    d_replace.name = "d"
+    z_clone = clone_while_sharing_some_variables([z], kept_variables, {d: d_replace})[0]
+    assert z_clone is not z
+    cloned_ancestors = list(ancestors([z_clone]))
+    for kept_var in [x, obs]:
+        assert kept_var in cloned_ancestors
+    for different_var in [d, y]:
+        assert different_var not in cloned_ancestors
+    assert np.all(get_var_by_name([z_clone], "d")[0].eval() == 0)
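Closing note for reviewers: for purely deterministic outputs, the looped and vectorized code paths should agree elementwise. A hedged sketch (toy model, illustrative posterior values), assuming `vectorize_over_posterior` keeps its current signature:

```python
import numpy as np
import pymc as pm
import xarray as xr

from pymc.pytensorf import compile
from pymc.sampling.forward import loop_over_posterior, vectorize_over_posterior

with pm.Model():
    x = pm.Normal("x")
    det = pm.Deterministic("det", 2 * x + 1)

chains, draws = 2, 25
posterior = xr.Dataset(
    {
        "x": xr.DataArray(
            np.random.default_rng(0).normal(size=(chains, draws)),
            dims=("chain", "draw"),
            coords={"chain": np.arange(chains), "draw": np.arange(draws)},
        )
    }
)

# det has no remaining RVs once x is substituted, so updates is empty and both
# graphs are deterministic functions of the posterior values.
looped, updates = loop_over_posterior([det], posterior, input_rvs=[x])
vectorized = vectorize_over_posterior([det], posterior, input_rvs=[x])
np.testing.assert_allclose(
    compile([], looped, updates=updates)()[0],
    compile([], vectorized)()[0],
)
```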