From 1ece82c413efa69f7b0679fa8d2bee475b678349 Mon Sep 17 00:00:00 2001 From: Eric Ma Date: Thu, 20 Nov 2025 09:40:36 -0500 Subject: [PATCH] Fix: Deterministic variables in sample_posterior_predictive no longer cause resampling - Fix compile_forward_sampling_function to recompute Deterministic variables from trace values instead of resampling their dependencies - Add comprehensive test suite covering edge cases - Update documentation to reflect correct behavior Fixes issue where Deterministic variables in var_names would incorrectly force their dependencies to be resampled, causing incorrect variance. --- pymc/sampling/forward.py | 120 ++++++- ...test_deterministic_posterior_predictive.py | 327 ++++++++++++++++++ 2 files changed, 437 insertions(+), 10 deletions(-) create mode 100644 tests/sampling/test_deterministic_posterior_predictive.py diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index f3ae1ba098..6440512d27 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -117,6 +117,7 @@ def compile_forward_sampling_function( givens_dict: dict[Variable, Any] | None = None, constant_data: dict[str, np.ndarray] | None = None, constant_coords: set[str] | None = None, + deterministics: list[Variable] | None = None, **kwargs, ) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]: """Compile a function to draw samples, conditioned on the values of some variables. 
@@ -206,6 +207,8 @@ def compile_forward_sampling_function( constant_data = {} if constant_coords is None: constant_coords = set() + if deterministics is None: + deterministics = [] # We define a helper function to check if shared values match to an array def shared_value_matches(var): @@ -216,6 +219,10 @@ def shared_value_matches(var): current_shared_value = var.get_value(borrow=True) return np.array_equal(old_array_value, current_shared_value) + # Helper function to check if a node is a Deterministic + def is_deterministic(node): + return node in deterministics + # We need a function graph to walk the clients and propagate the volatile property fg = FunctionGraph(outputs=outputs, clone=False) @@ -224,9 +231,61 @@ def shared_value_matches(var): fg.outputs, deps=lambda x: x.owner.inputs if x.owner else [] ) # type: ignore[call-overload] volatile_nodes: set[Any] = set() + vars_in_trace_set = set(vars_in_trace) for node in nodes: + # Check if this is a Deterministic in outputs with all inputs in trace + # Such Deterministics should NOT be volatile and should NOT propagate volatility backwards + is_det = is_deterministic(node) + is_det_in_outputs = node in fg.outputs and is_det + det_all_inputs_in_trace = ( + is_det_in_outputs + and node.owner + and all(inp in vars_in_trace_set for inp in node.owner.inputs) + ) + + # Skip marking this Deterministic as volatile if all inputs are in trace + if det_all_inputs_in_trace: + continue + + # Check if any input is volatile + # Special handling: If this node is a direct dependency of a Deterministic output + # that has all its inputs in trace, don't mark it volatile just because of that Deterministic + has_volatile_input = False + if node.owner: + for inp in node.owner.inputs: + if inp in volatile_nodes: + # Don't propagate volatility from Deterministics that have all inputs in trace + inp_is_det_with_all_inputs = ( + is_deterministic(inp) + and inp.owner + and all( + dep_inp in vars_in_trace_set for dep_inp in inp.owner.inputs + 
) + ) + if not inp_is_det_with_all_inputs: + # Also check: if this node is in trace and is a direct input to a Deterministic + # output that has all inputs in trace, don't mark it volatile + node_is_direct_input_to_safe_det = False + for output in fg.outputs: + if ( + is_deterministic(output) + and output.owner + and node in output.owner.inputs + and all( + dep_inp in vars_in_trace_set + for dep_inp in output.owner.inputs + ) + ): + node_is_direct_input_to_safe_det = True + break + + if not node_is_direct_input_to_safe_det: + has_volatile_input = True + break + if ( - node in fg.outputs + # Don't mark Deterministic outputs as volatile if all inputs are in trace + (node in fg.outputs and not det_all_inputs_in_trace) or node in givens_dict or ( # SharedVariables, except RandomState/Generators isinstance(node, SharedVariable) @@ -236,12 +295,33 @@ def shared_value_matches(var): or ( # Basic RVs that are not in the trace node in basic_rvs and node not in vars_in_trace ) - or ( # Variables that have any volatile input - node.owner and any(inp in volatile_nodes for inp in node.owner.inputs) - ) + or has_volatile_input ): volatile_nodes.add(node) + # Second pass: Unmark Deterministic outputs and their trace dependencies + # if all trace ancestors of the Deterministic are in trace + # This prevents Deterministic variables from causing their dependencies to be resampled + for output in fg.outputs: + if is_deterministic(output): + # Find all ancestors that are in basic_rvs (the actual random variables) + output_ancestors = ancestors([output], blockers=[]) + trace_ancestors = [ + anc + for anc in output_ancestors + if anc in vars_in_trace_set and anc in basic_rvs + ] + all_trace_ancestors_in_trace = all( + anc in vars_in_trace_set for anc in trace_ancestors + ) + + if all_trace_ancestors_in_trace and trace_ancestors: + # Unmark the Deterministic itself - it will be recomputed from trace values + volatile_nodes.discard(output) + # Unmark its trace ancestors - they should use 
trace values, not be resampled + for anc in trace_ancestors: + volatile_nodes.discard(anc) + # Collect the function inputs by walking the graph from the outputs. Inputs will be: # 1. Random variables that are not volatile # 2. Variables that have no owner and are not constant or shared @@ -277,7 +357,8 @@ def expand(node): return ( compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs), - set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled + set(basic_rvs) + & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled ) @@ -760,8 +841,13 @@ def sample_posterior_predictive( # Sampling: [x, y, z, obs] - .. danger:: Including a :func:`~pymc.Deterministic` in `var_names` may incorrectly force a random variable to be resampled, as happens with ``z`` in the following example: + .. note:: When including a :func:`~pymc.Deterministic` in `var_names`, the Deterministic variable + and its dependencies that are in the trace will be recomputed from posterior samples, + not resampled. This allows safe recomputation of Deterministic variables for new data + or coordinates while preserving the correct uncertainty quantification. + However, if a Deterministic depends on a random variable that is not in the trace, + that random variable will still be resampled. For example: .. 
code :: python @@ -775,8 +861,13 @@ def sample_posterior_predictive( idata = pm.sample(tune=10, draws=10, chains=2, **kwargs) + # If z is not in the trace, it will be resampled pm.sample_posterior_predictive(idata, var_names=["det_xy", "det_z"], **kwargs) - # Sampling: [z] + # Sampling: [z] # z is resampled because it's not in trace + + # But if all dependencies are in trace, no resampling occurs + pm.sample_posterior_predictive(idata, var_names=["det_xy"], **kwargs) + # Sampling: [] # No resampling, det_xy recomputed from x and y in trace Controlling the number of samples @@ -834,7 +925,9 @@ def sample_posterior_predictive( if isinstance(trace, InferenceData): _constant_data = getattr(trace, "constant_data", None) if _constant_data is not None: - trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) + trace_coords.update( + {str(k): v.data for k, v in _constant_data.coords.items()} + ) constant_data.update({str(k): v.data for k, v in _constant_data.items()}) idata = trace observed_data = trace.get("observed_data", None) @@ -914,6 +1007,7 @@ def sample_posterior_predictive( random_seed=random_seed, constant_data=constant_data, constant_coords=constant_coords, + deterministics=model.deterministics, **compile_kwargs, ) sampler_fn = point_wrapper(_sampler_fn) @@ -941,7 +1035,11 @@ def sample_posterior_predictive( if hasattr(_trace, "_straces"): chain_idx, point_idx = np.divmod(idx, len_trace) chain_idx = chain_idx % nchain - param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) + param = ( + cast(MultiTrace, _trace) + ._straces[chain_idx] + .point(point_idx) + ) # ... 
or a PointList else: param = cast(PointList, _trace)[idx % (len_trace * nchain)] @@ -1061,7 +1159,9 @@ def vectorize_over_posterior( ) if rv in all_rvs ]: - rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs]) + rv_ancestors = ancestors( + [rv], blockers=[*needed_rvs, *independent_rvs, *outputs] + ) if ( rv not in needed_rvs and not ({*outputs, *independent_rvs} & set(rv_ancestors)) diff --git a/tests/sampling/test_deterministic_posterior_predictive.py b/tests/sampling/test_deterministic_posterior_predictive.py new file mode 100644 index 0000000000..bb5a2c488e --- /dev/null +++ b/tests/sampling/test_deterministic_posterior_predictive.py @@ -0,0 +1,327 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for sample_posterior_predictive with Deterministic variables. + +These tests verify that Deterministic variables are correctly recomputed from +posterior samples rather than causing their dependencies to be resampled. 
+"""
+import logging
+
+import numpy as np
+import numpy.testing as npt
+import pytest
+import pytensor
+
+import pymc as pm
+from pymc.testing import fast_unstable_sampling_mode
+
+
+class TestDeterministicPosteriorPredictive:
+    """Test that Deterministic variables don't cause resampling of dependencies."""
+
+    def test_deterministic_recomputed_not_resampled(self):
+        """
+        Test that Deterministic variables are recomputed from posterior samples,
+        not causing their dependencies to be resampled.
+
+        This addresses the bug where including a Deterministic in var_names
+        would incorrectly force its dependencies to be resampled.
+        """
+        rng = np.random.default_rng(42)
+        time_coords = np.array([0.0, 12.0, 24.0, 48.0])
+
+        # mu_grid must carry named dims and coords: the checks below select
+        # entries by label with ``.sel(test=..., time_hours=...)``.
+        coords = {"test": [0, 1], "time_hours": time_coords}
+
+        with pm.Model(coords=coords) as model:
+            # Hierarchical model
+            intercept_mu = pm.Normal("intercept_mu", mu=0, sigma=1)
+            intercept_sigma = pm.HalfNormal("intercept_sigma", sigma=1)
+            slope_mu = pm.Normal("slope_mu", mu=0, sigma=1)
+            slope_sigma = pm.HalfNormal("slope_sigma", sigma=1)
+
+            intercepts = pm.Normal(
+                "intercepts", mu=intercept_mu, sigma=intercept_sigma, shape=(2,)
+            )
+            slopes = pm.Normal("slopes", mu=slope_mu, sigma=slope_sigma, shape=(2,))
+
+            # Deterministic variable that depends on intercepts and slopes
+            mu_grid = pm.Deterministic(
+                "mu_grid",
+                intercepts[:, None] + slopes[:, None] * time_coords[None, :],
+                dims=("test", "time_hours"),
+            )
+
+            sigma = pm.HalfNormal("sigma", sigma=1)
+            pm.Normal(
+                "y_obs",
+                mu=mu_grid[0, :],
+                sigma=sigma,
+                observed=rng.normal(0, 1, size=4),
+            )
+
+            # Sample (inside the model context: pm.Metropolis() needs a model)
+            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+                idata = pm.sample(
+                    tune=100,
+                    draws=100,
+                    chains=2,
+                    step=pm.Metropolis(),
+                    return_inferencedata=True,
+                    compute_convergence_checks=False,
+                    random_seed=rng,
+                    progressbar=False,
+                )
+
+        # Get posterior variance for mu_grid
+        mu_grid_post = idata.posterior.mu_grid.sel(test=0, time_hours=24.0)
+        var_post = float(mu_grid_post.var().values)
+
+        # Use sample_posterior_predictive with Deterministic in var_names
+        with model:
+            idata_pp = pm.sample_posterior_predictive(
+                idata,
+                var_names=["mu_grid"],
+                predictions=True,
+                extend_inferencedata=False,
+                progressbar=False,
+                random_seed=rng,
+            )
+
+        # Check that mu_grid variance matches posterior
+        mu_grid_pred = idata_pp.predictions.mu_grid.sel(test=0, time_hours=24.0)
+        var_pred = float(mu_grid_pred.var().values)
+
+        # Variance should match (within tolerance)
+        npt.assert_allclose(var_pred, var_post, rtol=0.1)
+
+        # Values should be highly correlated (near 1.0)
+        correlation = np.corrcoef(
+            mu_grid_post.values.flatten(), mu_grid_pred.values.flatten()
+        )[0, 1]
+        assert correlation > 0.99, f"Correlation too low: {correlation}"
+
+    def test_deterministic_with_random_variable_dependent(self):
+        """
+        Test that random variables depending on Deterministic are sampled correctly.
+
+        When y_obs depends on mu_det (Deterministic), y_obs should be sampled
+        using the recomputed mu_det, not a resampled one.
+        """
+        rng = np.random.default_rng(43)
+
+        with pm.Model() as model:
+            x = pm.Normal("x", mu=0, sigma=1)
+            mu_det = pm.Deterministic("mu_det", x + 1)
+            sigma = pm.HalfNormal("sigma", sigma=1)
+            pm.Normal(
+                "y_obs",
+                mu=mu_det,
+                sigma=sigma,
+                observed=rng.normal(0, 1, size=10),
+            )
+
+            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+                idata = pm.sample(
+                    tune=100,
+                    draws=100,
+                    chains=2,
+                    step=pm.Metropolis(),
+                    return_inferencedata=True,
+                    compute_convergence_checks=False,
+                    random_seed=rng,
+                    progressbar=False,
+                )
+
+        # Get posterior values
+        mu_det_post = idata.posterior.mu_det
+        sigma_post = idata.posterior.sigma
+        var_mu_det = float(mu_det_post.var().values)
+        sigma_mean_sq = float((sigma_post**2).mean().values)
+
+        # Use sample_posterior_predictive
+        with model:
+            idata_pp = pm.sample_posterior_predictive(
+                idata,
+                var_names=["mu_det", "y_obs"],
+                predictions=True,
+                extend_inferencedata=False,
+                progressbar=False,
+                random_seed=rng,
+            )
+
+        # Check mu_det is recomputed (not resampled)
+        mu_det_pred = idata_pp.predictions.mu_det
+        var_mu_det_pred = float(mu_det_pred.var().values)
+        npt.assert_allclose(var_mu_det_pred, var_mu_det, rtol=0.1)
+
+        # Check y_obs variance is correct
+        # Expected: var(y_obs) ≈ var(mu_det) + E[sigma^2]
+        y_obs_pred = idata_pp.predictions.y_obs
+        var_y_obs_pred = float(y_obs_pred.var().values)
+        expected_var = var_mu_det + sigma_mean_sq
+
+        # Should match within reasonable tolerance (y_obs is sampled, so some variance)
+        npt.assert_allclose(var_y_obs_pred, expected_var, rtol=0.3)
+
+    def test_deterministic_nested_dependencies(self):
+        """
+        Test Deterministic with nested dependencies (Deterministic depends on
+        Deterministic that depends on random variables).
+
+        Edge case: Multiple levels of Deterministic variables.
+        """
+        rng = np.random.default_rng(44)
+
+        with pm.Model() as model:
+            x = pm.Normal("x", mu=0, sigma=1)
+            y = pm.Normal("y", mu=0, sigma=1)
+
+            # Nested Deterministics
+            det1 = pm.Deterministic("det1", x + y)
+            det2 = pm.Deterministic("det2", det1 * 2)
+            pm.Deterministic("det3", det2 + 1)
+
+            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+                idata = pm.sample(
+                    tune=100,
+                    draws=100,
+                    chains=2,
+                    step=pm.Metropolis(),
+                    return_inferencedata=True,
+                    compute_convergence_checks=False,
+                    random_seed=rng,
+                    progressbar=False,
+                )
+
+        # Get posterior variance
+        det3_post = idata.posterior.det3
+        var_post = float(det3_post.var().values)
+
+        # Use sample_posterior_predictive
+        with model:
+            idata_pp = pm.sample_posterior_predictive(
+                idata,
+                var_names=["det3"],
+                predictions=True,
+                extend_inferencedata=False,
+                progressbar=False,
+                random_seed=rng,
+            )
+
+        # Check variance matches
+        det3_pred = idata_pp.predictions.det3
+        var_pred = float(det3_pred.var().values)
+        npt.assert_allclose(var_pred, var_post, rtol=0.1)
+
+        # Check correlation
+        correlation = np.corrcoef(
+            det3_post.values.flatten(), det3_pred.values.flatten()
+        )[0, 1]
+        assert correlation > 0.99, f"Correlation too low: {correlation}"
+
+    def test_deterministic_mixed_trace_dependencies(self):
+        """
+        Test a Deterministic in a model where not every variable is in the trace.
+
+        Edge case: the model contains a random variable (z) that is excluded
+        from the stored trace. ``det`` depends only on x and y, which are in
+        the trace, so it must be recomputed from their trace values. Neither
+        ``det`` nor ``z`` is stored, so the expected value is computed by hand.
+        """
+        rng = np.random.default_rng(45)
+
+        with pm.Model() as model:
+            x = pm.Normal("x", mu=0, sigma=1)
+            y = pm.Normal("y", mu=0, sigma=1)
+            # z is part of the model but is not stored in the trace (see var_names)
+            pm.Normal("z", mu=0, sigma=1)
+
+            # det depends only on x and y, both of which are stored in the trace
+            pm.Deterministic("det", x + y)
+
+            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+                idata = pm.sample(
+                    tune=100,
+                    draws=100,
+                    chains=2,
+                    step=pm.Metropolis(),
+                    var_names=["x", "y"],  # Only store x and y in the trace
+                    return_inferencedata=True,
+                    compute_convergence_checks=False,
+                    random_seed=rng,
+                    progressbar=False,
+                )
+
+        # Get posterior variance
+        x_post = idata.posterior.x
+        y_post = idata.posterior.y
+        det_manual = x_post + y_post
+        var_manual = float(det_manual.var().values)
+
+        # Use sample_posterior_predictive
+        with model:
+            idata_pp = pm.sample_posterior_predictive(
+                idata,
+                var_names=["det"],
+                predictions=True,
+                extend_inferencedata=False,
+                progressbar=False,
+                random_seed=rng,
+            )
+
+        # Check variance matches manual computation
+        det_pred = idata_pp.predictions.det
+        var_pred = float(det_pred.var().values)
+        npt.assert_allclose(var_pred, var_manual, rtol=0.1)
+
+    def test_deterministic_no_resampling_logged(self, caplog):
+        """
+        Test that when Deterministic is in var_names, no variables are logged
+        as being sampled (Sampling: []).
+
+        This verifies that dependencies are not being resampled.
+        """
+        rng = np.random.default_rng(46)
+        caplog.set_level(logging.INFO)
+
+        with pm.Model() as model:
+            x = pm.Normal("x", mu=0, sigma=1)
+            y = pm.Normal("y", mu=0, sigma=1)
+            pm.Deterministic("det", x + y)
+
+            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
+                idata = pm.sample(
+                    tune=100,
+                    draws=100,
+                    chains=2,
+                    step=pm.Metropolis(),
+                    return_inferencedata=True,
+                    compute_convergence_checks=False,
+                    random_seed=rng,
+                    progressbar=False,
+                )
+
+        # Drop everything logged by pm.sample so that the assertion below only
+        # looks at what sample_posterior_predictive reports.
+        caplog.clear()
+
+        with model:
+            pm.sample_posterior_predictive(
+                idata,
+                var_names=["det"],
+                predictions=True,
+                extend_inferencedata=False,
+                progressbar=False,
+                random_seed=rng,
+            )
+
+        # An empty list after "Sampling:" means no random variable was resampled.
+        assert "Sampling: []" in caplog.text, caplog.text