142 changes: 132 additions & 10 deletions pymc/logprob/abstract.py
@@ -40,7 +40,7 @@
from collections.abc import Sequence
from functools import singledispatch

from pytensor.graph.op import Op
from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
@@ -53,7 +53,7 @@
f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.",
FutureWarning,
)
return MeasurableOpMixin
return MeasurableOp

Codecov warning: added line pymc/logprob/abstract.py#L56 was not covered by tests.

raise AttributeError(f"module {__name__} has no attribute {name}")

@@ -150,14 +150,7 @@
MeasurableOp.register(RandomVariable)
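# Illustrative note (not part of this diff): with this ABC registration, any
# RandomVariable instance counts as a MeasurableOp, e.g.:
#
#   import pytensor.tensor as pt
#   x = pt.random.normal()
#   isinstance(x.owner.op, MeasurableOp)  # True, via virtual subclassing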


class MeasurableOpMixin(MeasurableOp):
"""MeasurableOp Mixin with a distinctive string representation"""

def __str__(self):
return f"Measurable{super().__str__()}"


class MeasurableElemwise(MeasurableOpMixin, Elemwise):
class MeasurableElemwise(MeasurableOp, Elemwise):
"""Base class for Measurable Elemwise variables"""

valid_scalar_types: tuple[MetaType, ...] = ()
@@ -169,3 +162,132 @@
f"Acceptable types are {self.valid_scalar_types}"
)
super().__init__(scalar_op, *args, **kwargs)

def __str__(self):
return f"Measurable{super().__str__()}"


class ValuedRV(Op):
r"""Represents the association of a measurable variable and its value.

A `ValuedRV` node represents the pair :math:`(Y, y)`, where `y` is the value at which :math:`Y`'s density
or probability mass function is evaluated.

The log-probability function takes such pairs as input, making these nodes the intermediate form
through which a model graph is converted into a log-probability graph.


Notes
-----
The introduction of these operations achieves two goals:
1. Identify the conditioning points between multiple, potentially interdependent measurable variables,
and introduce the respective value variables in the IR graph.
2. Prevent automatic rewrites across conditioning points.

Regarding point 2: in the current framework, an RV's logp cannot depend on a transformation of the value variable
of a second RV it depends on. While this is mathematically trivial, we don't have the machinery to achieve it.

The only case where we do something like this is the ad-hoc transform_value rewrite, but there we are
told explicitly which value variables must be transformed before being used in the density of dependent RVs.

For example, the following is not supported:

```python
x_log = pt.random.normal()
x = pt.exp(x_log)
y = pt.random.normal(loc=x_log)

x_value = pt.scalar()
y_value = pt.scalar()
conditional_logp({x: x_value, y: y_value})
```

Our framework doesn't know that the density of `y` should depend on a (log) transform of `x_value`.

Importantly, we need to prevent this limitation from being introduced automatically by our IR rewrites.
For example, given the following:

```python
a_base = pm.Normal.dist()
a = a_base * 5
b = pm.Normal.dist(a * 8)

a_value = pt.scalar()
b_value = pt.scalar()
conditional_logp({a: a_value, b: b_value})
```

We do not want `b` to be rewritten as `pm.Normal.dist(a_base * 40)`, as it would then be disconnected from the
valued `a` associated with `pm.Normal.dist(a_base * 5)`. By introducing `ValuedRV` nodes, the graph looks like:

```python
a_base = pm.Normal.dist()
a = valued_rv(a_base * 5, a_value)
b = valued_rv(a * 8, b_value)
```

Since PyTensor doesn't know what to do with `ValuedRV` nodes, there is no risk of rewriting across them
and breaking the dependency of `b` on `a`. The new nodes isolate the subgraphs between conditioning points.
"""

def make_node(self, rv, value):
assert isinstance(rv, Variable)
assert isinstance(value, Variable)
return Apply(self, [rv, value], [rv.type(name=rv.name)])

def perform(self, node, inputs, out):
raise NotImplementedError("ValuedVar should not be present in the final graph!")

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


valued_rv = ValuedRV()
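# Illustrative sketch (not part of this diff): pairing a measurable variable
# with a hypothetical value variable by hand. The output mirrors the type of
# the wrapped variable, and `perform` raises, so the node can never be
# evaluated directly:
#
#   import pytensor.tensor as pt
#   from pymc.logprob.abstract import valued_rv
#
#   x = pt.random.normal()
#   x_value = pt.scalar("x_value")
#   x_valued = valued_rv(x, x_value)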


class PromisedValuedRV(Op):
r"""Marks a variable as being promised a valued variable that will only be assigned by the logprob method.

Some measurable RVs, like Join/MakeVector, can combine multiple, potentially interdependent RVs into a single
composite valued node. Only in the logp function is this value split and sent to each component,
but we still want to achieve the same goals that ValuedRVs achieve during the IR rewrites.

Here is an example analogous to the one described in the docstring of ValuedRV:

```python
a_base = pt.random.normal()
a = a_base * 5
b = pt.random.normal(a * 8)
ab = pt.stack([a, b])
ab_value = pt.vector(shape=(2,))

logp(ab, ab_value)
```

The density of `ab[1]` (that is, `b`) depends on `ab_value[1]` and `ab_value[0] * 8`, but this is not apparent
in the IR, because the values of `a` and `b` are merged together and will only be split by the logp
function (see why below). For the time being, we introduce a `PromisedValuedRV` to isolate the graphs of `a` and `b`,
and freeze the dependency of `b` on `a` (rather than on `a_base`).

Now, why use a new Op and not just `ValuedRV`? Just for convenience! In the end we still want a function from
`ab_value` to `stack([logp(a), logp(b | a)])`, and if we split the values ahead of time we wouldn't know how to
stack them later (or even know that we were supposed to).
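
As a rough end-to-end sketch (illustrative only; assumes `logp` handles stacked measurable
variables as intended here), the compiled function maps the single `ab_value` vector to both
logp terms at once:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc.logprob.basic import logp

a_base = pt.random.normal()
a = a_base * 5
b = pt.random.normal(a * 8)
ab = pt.stack([a, b])
ab_value = pt.vector(shape=(2,))

# One graph from ab_value to stack([logp(a), logp(b | a)])
ab_logp = logp(ab, ab_value)
fn = pytensor.function([ab_value], ab_logp)
fn(np.array([0.5, 4.0]))
```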

One final point: while this achieves the same goal as introducing `ValuedRV`s, it already constitutes a form of inference
(knowing how/when to measure Join/MakeVector), so it must be done as an IR rewrite. However, it has to happen
before any other rewrites, which is why the related rewrites are registered in `early_measurable_ir_rewrites_db`.

"""

def make_node(self, rv):
assert isinstance(rv, Variable)
return Apply(self, [rv], [rv.type(name=rv.name)])

def perform(self, node, inputs, out):
raise NotImplementedError("PromisedValuedRV should not be present in the final graph!")

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]


promised_valued_rv = PromisedValuedRV()
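# Rough, illustrative sketch of the kind of IR these Ops produce for the
# stacked example above (assumed shape only; the actual graphs are built by
# rewrites registered in `early_measurable_ir_rewrites_db`):
#
#   import pytensor.tensor as pt
#   from pymc.logprob.abstract import promised_valued_rv, valued_rv
#
#   a_base = pt.random.normal()
#   a_promised = promised_valued_rv(a_base * 5)  # value arrives via ab_value
#   b = pt.random.normal(a_promised * 8)
#   ab_value = pt.vector(shape=(2,))
#   ab = valued_rv(pt.stack([a_promised, b]), ab_value)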
129 changes: 56 additions & 73 deletions pymc/logprob/basic.py
@@ -36,22 +36,17 @@

import warnings

from collections import deque
from collections.abc import Sequence
from typing import TypeAlias

import numpy as np
import pytensor.tensor as pt

from pytensor import config
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
graph_inputs,
io_toposort,
)
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
from pytensor.tensor.variable import TensorVariable

@@ -65,7 +60,7 @@
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

TensorLike: TypeAlias = Variable | float | np.ndarray
@@ -210,8 +205,9 @@ def normal_logp(value, mu, sigma):
try:
return _logprob_helper(rv, value, **kwargs)
except NotImplementedError:
fgraph, _, _ = construct_ir_fgraph({rv: value})
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_var] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_var.owner.inputs
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
@@ -308,9 +304,10 @@ def normal_logcdf(value, mu, sigma):
return _logcdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _logcdf_helper(ir_rv, value, **kwargs)
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _logcdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
@@ -390,9 +387,10 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> TensorVariable:
return _icdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _icdf_helper(ir_rv, value, **kwargs)
fgraph = construct_ir_fgraph({rv: value})
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _icdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
@@ -476,111 +474,96 @@ def conditional_logp(
"""
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)

fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)

if extra_rewrites is not None:
extra_rewrites.rewrite(fgraph)

rv_remapper = fgraph.preserve_rv_mappings

# This is the updated random-to-value-vars map with the lifted/rewritten
# variables. The rewrites are supposed to produce new
# `MeasurableVariable`s that are amenable to `_logprob`.
updated_rv_values = rv_remapper.rv_values

# Some rewrites also transform the original value variables. This is the
# updated map from the new value variables to the original ones, which
# we want to use as the keys in the final dictionary output
original_values = rv_remapper.original_values

# When a `_logprob` has been produced for a `MeasurableVariable` node, all
# other references to it need to be replaced with its value-variable all
# throughout the `_logprob`-produced graphs. The following `dict`
# cumulatively maintains remappings for all the variables/nodes that needed
# to be recreated after replacing `MeasurableVariable`s with their
# value-variables. Since these replacements work in topological order, all
# the necessary value-variable replacements should be present for each
# node.
replacements = updated_rv_values.copy()
# Walk the graph from its inputs to its outputs and construct the
# log-probability
replacements = {}

# To avoid cloning the value variables (or ancestors of value variables),
# we map them to themselves in the `replacements` `dict`
# (i.e. entries already existing in `replacements` aren't cloned)
replacements.update(
{
v: v
for v in ancestors(rv_values.values())
if (not isinstance(v, Constant) and v not in replacements)
}
{v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
)

# Walk the graph from its inputs to its outputs and construct the
# log-probability
q = deque(fgraph.toposort())
logprob_vars = {}

while q:
node = q.popleft()
values_to_logprobs = {}
original_values = tuple(rv_values.values())

# TODO: This seems too convoluted, can we just replace all RVs by their values,
# except for the fgraph outputs (for which we want to call _logprob on)?
for node in fgraph.toposort():
if not isinstance(node.op, MeasurableOp):
continue

q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
valued_nodes = get_related_valued_nodes(node, fgraph)

if not q_values:
if not valued_nodes:
continue

node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
node_values = [valued_var.inputs[1] for valued_var in valued_nodes]
node_output_idxs = [
fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes
]

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes that follow.
for node_rv, node_value in zip(node_rvs, node_values):
replacements[node_rv] = node_value

remapped_vars = replace_vars_in_graphs(
graphs=q_values + list(node.inputs),
graphs=node_values + list(node.inputs),
replacements=replacements,
)
q_values = remapped_vars[: len(q_values)]
q_rv_inputs = remapped_vars[len(q_values) :]
node_values = remapped_vars[: len(node_values)]
node_inputs = remapped_vars[len(node_values) :]

q_logprob_vars = _logprob(
node_logprobs = _logprob(
node.op,
q_values,
*q_rv_inputs,
node_values,
*node_inputs,
**kwargs,
)

if not isinstance(q_logprob_vars, list | tuple):
q_logprob_vars = [q_logprob_vars]
if not isinstance(node_logprobs, list | tuple):
node_logprobs = [node_logprobs]

for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars):
q_value_var = original_values[q_value_var]
for node_output_idx, node_value, node_logprob in zip(
node_output_idxs, node_values, node_logprobs
):
original_value = original_values[node_output_idx]

if q_value_var.name:
q_logprob_var.name = f"{q_value_var.name}_logprob"
if original_value.name:
node_logprob.name = f"{original_value.name}_logprob"

if q_value_var in logprob_vars:
if original_value in values_to_logprobs:
raise ValueError(
f"More than one logprob term was assigned to the value var {q_value_var}"
f"More than one logprob term was assigned to the value var {original_value}"
)

logprob_vars[q_value_var] = q_logprob_var

# Recompute test values for the changes introduced by the replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
values_to_logprobs[original_value] = node_logprob

missing_value_terms = set(original_values.values()) - set(logprob_vars.keys())
missing_value_terms = set(original_values) - set(values_to_logprobs)
if missing_value_terms:
raise RuntimeError(
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
)

logprob_expressions = list(logprob_vars.values())
cleanup_ir(logprob_expressions)
logprobs = list(values_to_logprobs.values())
cleanup_ir(logprobs)

if warn_rvs:
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions)
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
if rvs_in_logp_expressions:
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)

return logprob_vars
return values_to_logprobs
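# Illustrative usage sketch (not part of this diff; mirrors the example in the
# ValuedRV docstring). The returned dict is keyed by the original value
# variables, and b's term depends on a_value rather than on a resampled `a`:
#
#   import pytensor.tensor as pt
#   from pymc.logprob.basic import conditional_logp
#
#   a_base = pt.random.normal()
#   a = a_base * 5
#   b = pt.random.normal(a * 8)
#   a_value = pt.scalar("a_value")
#   b_value = pt.scalar("b_value")
#
#   terms = conditional_logp({a: a_value, b: b_value})
#   logp_a, logp_b = terms[a_value], terms[b_value]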


def transformed_conditional_logp(