pymc/sampling/forward.py: 120 changes (110 additions, 10 deletions)
@@ -117,6 +117,7 @@ def compile_forward_sampling_function(
givens_dict: dict[Variable, Any] | None = None,
constant_data: dict[str, np.ndarray] | None = None,
constant_coords: set[str] | None = None,
deterministics: list[Variable] | None = None,
**kwargs,
) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]:
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -206,6 +207,8 @@ def compile_forward_sampling_function(
constant_data = {}
if constant_coords is None:
constant_coords = set()
if deterministics is None:
deterministics = []

# We define a helper function to check if shared values match to an array
def shared_value_matches(var):
@@ -216,6 +219,10 @@ def shared_value_matches(var):
current_shared_value = var.get_value(borrow=True)
return np.array_equal(old_array_value, current_shared_value)

# Helper function to check if a node is a Deterministic
def is_deterministic(node):
return node in deterministics

# We need a function graph to walk the clients and propagate the volatile property
fg = FunctionGraph(outputs=outputs, clone=False)

@@ -224,9 +231,61 @@ def shared_value_matches(var):
fg.outputs, deps=lambda x: x.owner.inputs if x.owner else []
) # type: ignore[call-overload]
volatile_nodes: set[Any] = set()
vars_in_trace_set = set(vars_in_trace)
for node in nodes:
# Check if this is a Deterministic in outputs with all inputs in trace
# Such Deterministics should NOT be volatile and should NOT propagate volatility backwards
is_det = is_deterministic(node)
is_det_in_outputs = node in fg.outputs and is_det
det_all_inputs_in_trace = (
is_det_in_outputs
and node.owner
and all(inp in vars_in_trace_set for inp in node.owner.inputs)
)

# Skip marking this Deterministic as volatile if all inputs are in trace
if det_all_inputs_in_trace:
continue

        # Check whether any input is volatile.
        # Special handling: volatility is not propagated
        # (a) from a Deterministic input whose own inputs are all in the trace, or
        # (b) to a node that is itself a direct input of a Deterministic output
        #     whose inputs are all in the trace.
        has_volatile_input = False
        if node.owner:
            node_is_direct_input_to_safe_det = any(
                is_deterministic(output)
                and output.owner
                and node in output.owner.inputs
                and all(
                    dep_inp in vars_in_trace_set for dep_inp in output.owner.inputs
                )
                for output in fg.outputs
            )
            if not node_is_direct_input_to_safe_det:
                for inp in node.owner.inputs:
                    if inp not in volatile_nodes:
                        continue
                    # Don't propagate volatility from Deterministics that have
                    # all of their inputs in the trace
                    inp_is_det_with_all_inputs = (
                        is_deterministic(inp)
                        and inp.owner
                        and all(
                            dep_inp in vars_in_trace_set
                            for dep_inp in inp.owner.inputs
                        )
                    )
                    if not inp_is_det_with_all_inputs:
                        has_volatile_input = True
                        break

if (
# Don't mark Deterministic outputs as volatile if all inputs are in trace
(node in fg.outputs and not det_all_inputs_in_trace)
or node in givens_dict
or ( # SharedVariables, except RandomState/Generators
isinstance(node, SharedVariable)
@@ -236,12 +295,33 @@ def shared_value_matches(var):
or ( # Basic RVs that are not in the trace
node in basic_rvs and node not in vars_in_trace
)
or has_volatile_input
):
volatile_nodes.add(node)

# Second pass: Unmark Deterministic outputs and their trace dependencies
# if all trace ancestors of the Deterministic are in trace
# This prevents Deterministic variables from causing their dependencies to be resampled
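    # e.g. ``det_xy = Deterministic("det_xy", x + y)`` with ``x`` and ``y`` in the
    # trace: ``det_xy`` is recomputed from the posterior draws of ``x`` and ``y``
    # instead of drawing fresh values for them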
for output in fg.outputs:
if is_deterministic(output):
# Find all ancestors that are in basic_rvs (the actual random variables)
output_ancestors = ancestors([output], blockers=[])
trace_ancestors = [
anc
for anc in output_ancestors
if anc in vars_in_trace_set and anc in basic_rvs
]
            # ``trace_ancestors`` only contains variables that are already in the
            # trace, so the Deterministic can be recomputed whenever it is non-empty
            if trace_ancestors:
# Unmark the Deterministic itself - it will be recomputed from trace values
volatile_nodes.discard(output)
# Unmark its trace ancestors - they should use trace values, not be resampled
for anc in trace_ancestors:
volatile_nodes.discard(anc)

# Collect the function inputs by walking the graph from the outputs. Inputs will be:
# 1. Random variables that are not volatile
# 2. Variables that have no owner and are not constant or shared
@@ -277,7 +357,8 @@ def expand(node):

return (
compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
set(basic_rvs)
& (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
)


@@ -760,8 +841,13 @@ def sample_posterior_predictive(
# Sampling: [x, y, z, obs]


.. note:: When a :func:`~pymc.Deterministic` is included in `var_names`, it is
    recomputed from the posterior samples of its dependencies that are stored in the
    trace, rather than forcing those dependencies to be resampled. This allows
    Deterministic variables to be safely recomputed for new data or coordinates while
    preserving the correct uncertainty quantification.

    However, if a Deterministic depends on a random variable that is not in the trace,
    that random variable will still be resampled. For example:

.. code:: python

@@ -775,8 +861,13 @@ def sample_posterior_predictive(

idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)

# If z is not in the trace, it will be resampled
pm.sample_posterior_predictive(idata, var_names=["det_xy", "det_z"], **kwargs)
# Sampling: [z] # z is resampled because it's not in trace

# But if all dependencies are in trace, no resampling occurs
pm.sample_posterior_predictive(idata, var_names=["det_xy"], **kwargs)
# Sampling: [] # No resampling, det_xy recomputed from x and y in trace
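
    The same recomputation applies when data is changed between sampling and
    posterior prediction. A minimal sketch (the :func:`~pymc.Data` variable
    ``x_data`` and the model below are illustrative, not part of the example
    above):

    .. code:: python

        with pm.Model() as model:
            x_data = pm.Data("x_data", [0.0, 1.0, 2.0])
            beta = pm.Normal("beta")
            mu = pm.Deterministic("mu", beta * x_data)
            pm.Normal("obs", mu, observed=[0.0, 0.0, 0.0])

            idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)

            # Swap in new data; mu is recomputed from the posterior draws of beta
            pm.set_data({"x_data": [0.0, 1.0, 2.0, 3.0, 4.0]})
            pm.sample_posterior_predictive(idata, var_names=["mu"], **kwargs)
            # Sampling: []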


Controlling the number of samples
@@ -834,7 +925,9 @@ def sample_posterior_predictive(
if isinstance(trace, InferenceData):
_constant_data = getattr(trace, "constant_data", None)
if _constant_data is not None:
trace_coords.update(
{str(k): v.data for k, v in _constant_data.coords.items()}
)
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
idata = trace
observed_data = trace.get("observed_data", None)
@@ -914,6 +1007,7 @@ def sample_posterior_predictive(
random_seed=random_seed,
constant_data=constant_data,
constant_coords=constant_coords,
deterministics=model.deterministics,
**compile_kwargs,
)
sampler_fn = point_wrapper(_sampler_fn)
@@ -941,7 +1035,11 @@ def sample_posterior_predictive(
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
param = (
cast(MultiTrace, _trace)
._straces[chain_idx]
.point(point_idx)
)
# ... or a PointList
else:
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
@@ -1061,7 +1159,9 @@ def vectorize_over_posterior(
)
if rv in all_rvs
]:
rv_ancestors = ancestors(
[rv], blockers=[*needed_rvs, *independent_rvs, *outputs]
)
if (
rv not in needed_rvs
and not ({*outputs, *independent_rvs} & set(rv_ancestors))