diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f02afdc62..538edfbba 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -116,6 +116,7 @@ jobs: tests/logprob/test_censoring.py tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py + tests/logprob/test_linalg.py tests/logprob/test_mixture.py tests/logprob/test_order.py tests/logprob/test_rewriting.py diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index aaa8b2052..6b4911ae6 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -49,6 +49,7 @@ import pymc.logprob.censoring import pymc.logprob.cumsum import pymc.logprob.checks +import pymc.logprob.linalg import pymc.logprob.mixture import pymc.logprob.order import pymc.logprob.scan diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index f47c39e2b..281b4fb18 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -43,6 +43,7 @@ from pytensor.graph import Apply, Op, Variable from pytensor.graph.utils import MetaType from pytensor.tensor import TensorVariable +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable @@ -168,6 +169,10 @@ def __str__(self): return f"Measurable{super().__str__()}" +class MeasurableBlockwise(MeasurableOp, Blockwise): + """Base class for Measurable Blockwise variables.""" + + class ValuedRV(Op): r"""Represents the association of a measurable variable and its value. diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 8c76fd8ae..7753678d2 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -503,7 +503,7 @@ def conditional_logp( if not isinstance(node.op, MeasurableOp): continue - valued_nodes = get_related_valued_nodes(node, fgraph) + valued_nodes = get_related_valued_nodes(fgraph, node) if not valued_nodes: continue diff --git a/pymc/logprob/linalg.py b/pymc/logprob/linalg.py new file mode 100644 index 000000000..226b24a07 --- /dev/null +++ b/pymc/logprob/linalg.py @@ -0,0 +1,102 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytensor.tensor as pt + +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.math import _matrix_matrix_matmul + +from pymc.logprob.abstract import MeasurableBlockwise, MeasurableOp, _logprob, _logprob_helper +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables + + +class MeasurableMatMul(MeasurableBlockwise): + """Measurable matrix multiplication operation.""" + + right_measurable: bool + + def __init__(self, measurable_right: bool, **kwargs): + self.right_measurable = measurable_right + super().__init__(**kwargs) + + +@_logprob.register(MeasurableMatMul) +def logprob_measurable_matmul(op, values, l, r): # noqa: E741 + [y_value] = values + if op.right_measurable: + A, x = l, r + x_value = pt.linalg.solve(A, y_value) + else: + x, A = l, r + x_value = pt.linalg.solve(A.mT, y_value.mT).mT + + x_logp = _logprob_helper(x, x_value) + + # The operation has a support dimensionality of 2 + # We need to reduce it if it's still present in the base logp + if x_logp.type.ndim == x_value.type.ndim: + x_logp = pt.sum(x_logp, axis=(-1, -2)) + elif x_logp.type.ndim == x_value.type.ndim - 1: + x_logp = pt.sum(x_logp, axis=-1) + + _, log_abs_jac_det = pt.linalg.slogdet(A) + + return x_logp - log_abs_jac_det + + +@node_rewriter(tracks=[_matrix_matrix_matmul]) +def find_measurable_matmul(fgraph, node): + """Find measurable matrix-matrix multiplication operations.""" + if isinstance(node.op, MeasurableOp): + return None + + [out] = node.outputs + [l, r] = node.inputs # noqa: E741 + + # Check that not both a and r are measurable + measurable_inputs = filter_measurable_variables([l, r]) + if len(measurable_inputs) != 1: + return None + + [measurable_input] = measurable_inputs + + # Check the measurable input is not broadcasted + if measurable_input.type.broadcastable[:-2] != out.type.broadcastable[:-2]: + return None + + measurable_right = measurable_input is r + A = l if measurable_right else r + + # Check if the static shape already reveals a non-square matrix, + if ( + A.type.shape[-1] is not None + and A.type.shape[-2] is not None + and A.type.shape[-1] != A.type.shape[-2] + ): + return None + + # Check the other input is not potentially measurable + if check_potential_measurability([A]): + return None + + measurable_matmul = MeasurableMatMul(measurable_right=measurable_right, **node.op._props_dict()) + return [measurable_matmul(l, r)] + + +measurable_ir_rewrites_db.register( + find_measurable_matmul.__name__, + find_measurable_matmul, + "basic", + "linalg", +) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 55e506ad9..1ebb29638 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -468,7 +468,7 @@ def split_valued_ifelse(fgraph, node): # Single outputs IfElse return None - valued_output_nodes = get_related_valued_nodes(node, fgraph) + valued_output_nodes = get_related_valued_nodes(fgraph, node) if not valued_output_nodes: return None diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index cd390e13a..76baf31df 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -152,7 +152,11 @@ def remove_DiracDelta(fgraph, node): logprob_rewrites_db.register( "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic" ) -logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic") +logprob_rewrites_db.register( + "pre-canonicalize", + optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), + "basic", +) # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require @@ -175,7 +179,11 @@ def remove_DiracDelta(fgraph, node): ) -logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic") +logprob_rewrites_db.register( + "post-canonicalize", + optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"), + "basic", +) # Rewrites that remove IR Ops cleanup_ir_rewrites_db = LocalGroupDB() diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index ecd04b9c7..8626c20c6 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -421,7 +421,7 @@ def find_measurable_scans(fgraph, node): # Find outputs of scan that are directly valued. # These must be mapping outputs, such as `outputs_info = [None]` (i.e, no recurrence nit_sot outputs) direct_valued_outputs = [ - valued_node.inputs[0] for valued_node in get_related_valued_nodes(node, fgraph) + valued_node.inputs[0] for valued_node in get_related_valued_nodes(fgraph, node) ] if not all(valued_out in scan_args.outer_out_nit_sot for valued_out in direct_valued_outputs): return None @@ -434,7 +434,7 @@ def find_measurable_scans(fgraph, node): client.outputs[0] for out in node.outputs for client, _ in fgraph.clients[out] - if (isinstance(client.op, Subtensor) and get_related_valued_nodes(client, fgraph)) + if (isinstance(client.op, Subtensor) and get_related_valued_nodes(fgraph, client)) ] indirect_valued_outputs = [out.owner.inputs[0] for out in sliced_valued_outputs] if not all( diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index ec4f5c65a..5503ce32b 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -41,13 +41,19 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.tensor import TensorVariable from pytensor.tensor.basic import Join, MakeVector -from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import ( local_dimshuffle_rv_lift, ) -from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv +from pymc.logprob.abstract import ( + MeasurableOp, + ValuedRV, + _logprob, + _logprob_helper, + promised_valued_rv, +) from pymc.logprob.rewriting import ( assume_valued_outputs, early_measurable_ir_rewrites_db, @@ -57,6 +63,7 @@ from pymc.logprob.utils import ( check_potential_measurability, filter_measurable_variables, + get_related_valued_nodes, replace_rvs_by_values, ) from pymc.pytensorf import constant_fold @@ -183,6 +190,9 @@ class MeasurableDimShuffle(MeasurableOp, DimShuffle): # find it locally and fails when a new `Op` is initialized c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) # type: ignore[arg-type] + def __str__(self): + return f"Measurable{super().__str__()}" + @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs): @@ -215,29 +225,66 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs): return raw_logp.dimshuffle(redo_ds) +def _elemwise_univariate_chain(fgraph, node) -> bool: + # Check whether only Elemwise operations connect a base univariate RV to the valued node through var. + from pymc.distributions.distribution import SymbolicRandomVariable + from pymc.logprob.transforms import MeasurableTransform + + [inp] = node.inputs + [out] = node.outputs + + def elemwise_root(var: TensorVariable) -> TensorVariable | None: + if isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable): + return var + elif isinstance(var.owner.op, MeasurableTransform): + return elemwise_root(var.owner.inputs[var.owner.op.measurable_input_idx]) + else: + return None + + # Check that the root is a univariate distribution linked by only elemwise operations + root = elemwise_root(inp) + if root is None: + return False + elif root.owner.op.ndim_supp != 0: + # This is still fine if the variable is directly valued + return any(get_related_valued_nodes(fgraph, node)) + + def elemwise_leaf(var: TensorVariable, clients=fgraph.clients) -> bool: + var_clients = clients[var] + if len(var_clients) != 1: + return False + [(client, _)] = var_clients + if isinstance(client.op, ValuedRV): + return True + elif isinstance(client.op, Elemwise) and len(client.outputs) == 1: + return elemwise_leaf(client.outputs[0]) + else: + return False + + # Check that the path to the valued node consists only of elemwise operations + return elemwise_leaf(out) + + @node_rewriter([DimShuffle]) def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None: r"""Find `Dimshuffle`\s for which a `logprob` can be computed.""" - from pymc.distributions.distribution import SymbolicRandomVariable - if isinstance(node.op, MeasurableOp): return None if not filter_measurable_variables(node.inputs): return None - base_var = node.inputs[0] + # In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise + # operations separate it from the valued node. Further transformations likely need to know where + # the support axes are for a correct implementation (and thus assume they are the rightmost axes). + # TODO: When we include the support axis as meta information in each intermediate MeasurableVariable, + # we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360) + if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain( + fgraph, node + ): + return None - # We can only apply this rewrite directly to `RandomVariable`s, as those are - # the only `Op`s for which we always know the support axis. Other measurable - # variables can have arbitrary support axes (e.g., if they contain separate - # `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s - # should still be supported as long as the `DimShuffle`s can be merged/ - # lifted towards the base RandomVariable. - # TODO: If we include the support axis as meta information in each - # intermediate MeasurableVariable, we can lift this restriction. - if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable): - return None # pragma: no cover + base_var = node.inputs[0] measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)( base_var diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index f093ddbf2..1b5d4cd81 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -147,7 +147,7 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None: return None rv_node = node.inputs[0].owner - valued_nodes = get_related_valued_nodes(rv_node, fgraph) + valued_nodes = get_related_valued_nodes(fgraph, rv_node) rvs = [valued_var.inputs[0] for valued_var in valued_nodes] values = [valued_var.inputs[1] for valued_var in valued_nodes] transforms = [values_to_transforms.get(value, None) for value in values] diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index e96426fbe..9865226e4 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -320,7 +320,7 @@ def find_negated_var(var): return None -def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]: +def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]: """Get all ValuedVars related to the same RV node. Returns diff --git a/pyproject.toml b/pyproject.toml index 3674e2272..a8ffb06ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ ignore = [ "D101", # Missing docstring in public class "D102", # Missing docstring in public method "D103", # Missing docstring in public function + "D105", # Missing docstring in magic method ] [tool.ruff.lint.pydocstyle] diff --git a/tests/logprob/test_linalg.py b/tests/logprob/test_linalg.py new file mode 100644 index 000000000..047a0312b --- /dev/null +++ b/tests/logprob/test_linalg.py @@ -0,0 +1,85 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from pytensor.tensor.type import tensor + +from pymc.distributions import MatrixNormal, MvNormal, Normal +from pymc.logprob.basic import logp + + +@pytest.mark.parametrize("univariate", [True, False]) +@pytest.mark.parametrize("batch_shape", [(), (3,)]) +def test_matrix_vector_transform(univariate, batch_shape): + rng = np.random.default_rng(755) + + μ = rng.normal(size=(*batch_shape, 2)) + if univariate: + σ = np.abs(rng.normal(size=(*batch_shape, 2))) + Σ = np.eye(2) * (σ**2)[..., None] + x = Normal.dist(mu=μ, sigma=σ) + else: + A = rng.normal(size=(*batch_shape, 2, 2)) + Σ = np.swapaxes(A, -1, -2) @ A + x = MvNormal.dist(mu=μ, cov=Σ) + + c = rng.normal(size=(*batch_shape, 2)) + B = rng.normal(size=(*batch_shape, 2, 2)) + y = c + (B @ x[..., None]).squeeze(-1) + + # An affine transformed MvNormal is still a MvNormal + # https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Affine_transformation + ref_dist = MvNormal.dist( + mu=c + (B @ μ[..., None]).squeeze(-1), cov=B @ Σ @ np.swapaxes(B, -1, -2) + ) + test_y = rng.normal(size=(*batch_shape, 2)) + np.testing.assert_allclose( + logp(y, test_y).eval(), + logp(ref_dist, test_y).eval(), + ) + + +def test_matrix_matrix_transform(): + rng = np.random.default_rng(46) + + n, p = 2, 3 + M = rng.normal(size=(n, p)) + A = rng.normal(size=(n, n)) * 0.1 + U = A.T @ A + B = rng.normal(size=(p, p)) * 0.1 + V = B.T @ B + X = MatrixNormal.dist(mu=M, rowcov=U, colcov=V) + + D = rng.normal(size=(n, n)) + C = rng.normal(size=(p, p)) + Y = D @ X @ C + + # A linearly transformed MatrixNormal is still a MatrixNormal + # https://en.wikipedia.org/wiki/Matrix_normal_distribution#Transformation + ref_dist = MatrixNormal.dist(mu=D @ M @ C, rowcov=D @ U @ D.T, colcov=C.T @ V @ C) + test_Y = rng.normal(size=(n, p)) + np.testing.assert_allclose( + logp(Y, test_Y).eval(), + logp(ref_dist, test_Y).eval(), + rtol=1e-5, + ) + + +def test_broadcasted_matmul_fails(): + x = Normal.dist(size=(3, 2)) + A = tensor("A", shape=(4, 3, 3)) + y = A @ x + with pytest.raises(NotImplementedError): + logp(y, y.type()) diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index 13ac8cfff..e118ed69f 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -309,10 +309,7 @@ def test_join_mixed_ndim_supp(): (1, 2, 0), # Swap (0, 1, 2, "x"), # Expand ("x", 0, 1, 2), # Expand - ( - 0, - 2, - ), # Drop + (0, 2), # Drop (2, 0), # Swap and drop (2, 1, "x", 0), # Swap and expand ("x", 0, 2), # Expand and drop @@ -338,7 +335,7 @@ def test_measurable_dimshuffle(ds_order, multivariate): ref_logp = logp(base_rv, base_vv).dimshuffle(logp_ds_order) - # Disable local_dimshuffle_rv_lift to test fallback Aeppl rewrite + # Disable local_dimshuffle_rv_lift to test fallback logprob rewrite ir_rewriter = logprob_rewrites_db.query( RewriteDatabaseQuery(include=["basic"]).excluding("dimshuffle_lift") )