diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 8367642c4c..b0fd806d11 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node): @node_rewriter([Blockwise]) def rewrite_inv_inv(fgraph, node): """ - This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once. + This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), + we get back our original input without having to compute inverse once. - Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten. + Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to + be simply rewritten. Parameters ---------- @@ -855,6 +857,83 @@ def rewrite_det_kronecker(fgraph, node): return [det_final] +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") +@node_rewriter([Blockwise]) +def rewrite_solve_kron_to_solve(fgraph, node): + """ + Given a linear system of the form: + + .. math:: + + (A \\otimes B) x = y + + Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further, + define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as: + + .. math:: + + x = \text{vec}(B^{-1} Y A^{-T}) + + Eliminating the kronecker product from the expression. + """ + + if not isinstance(node.op.core_op, SolveBase): + return + + solve_op = node.op + props_dict = solve_op.core_op._props_dict() + b_ndim = props_dict["b_ndim"] + + A, b = node.inputs + + if not A.owner or not ( + isinstance(A.owner.op, KroneckerProduct) + or isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, KroneckerProduct) + ): + return + + x1, x2 = A.owner.inputs + + # If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid. + # We will proceed if they are unknown, but this makes the rewrite shape unsafe. + x1_core_shapes = x1.type.shape[-2:] + x2_core_shapes = x2.type.shape[-2:] + + if ( + all(shape is not None for shape in x1_core_shapes) + and x1_core_shapes[-1] != x1_core_shapes[-2] + ) or ( + all(shape is not None for shape in x2_core_shapes) + and x2_core_shapes[-1] != x2_core_shapes[-2] + ): + return None + + m, n = x1.shape[-2], x2.shape[-2] + batch_shapes = x1.shape[:-2] + + if b_ndim == 1: + # The rewritten expression will reshape B to be 2d. The easiest way to handle this is to just make a new + # solve node with n_ndim = 2 + props_dict["b_ndim"] = 2 + new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict)) + B = b.reshape((*batch_shapes, m, n)) + res = new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1)) + + else: + # If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional + # batch dimension + b_batch = b.shape[-1] + B = pt.moveaxis(b, -1, 0).reshape((b_batch, *batch_shapes, m, n)) + + res = pt.moveaxis(solve_op(x1, solve_op(x2, B.mT).mT), 0, -1).reshape( + (*batch_shapes, -1, b_batch) + ) + + return [res] + + @register_canonicalize @register_stabilize @node_rewriter([Blockwise]) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 38f7369bcc..995c6568d5 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -828,6 +828,156 @@ def test_slogdet_kronecker_rewrite(): ) +def count_kron_ops(fgraph): + return sum( + [ + isinstance(node.op, KroneckerProduct) + or ( + isinstance(node.op, Blockwise) + and isinstance(node.op.core_op, KroneckerProduct) + ) + for node in fgraph.apply_nodes + ] + ) + + +@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"]) +@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"]) +@pytest.mark.parametrize( + "solve_op, solve_kwargs", + [ + (pt.linalg.solve, {"assume_a": "gen"}), + (pt.linalg.solve, {"assume_a": "pos"}), + (pt.linalg.solve, {"assume_a": "upper triangular"}), + ], + ids=["general", "positive definite", "triangular"], +) +def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs): + # A and B have different shapes to make the test more interesting, but both need to be square matrices, otherwise + # the rewrite is invalid. + a_shape = (3, 3) if not add_batch else (2, 3, 3) + b_shape = (2, 2) if not add_batch else (2, 2, 2) + A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape) + + m, n = a_shape[-2], b_shape[-2] + y_shape = (m * n,) + if b_ndim == 2: + y_shape = (m * n, 3) + if add_batch: + y_shape = (2, *y_shape) + + y = pt.tensor("y", shape=y_shape) + C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B) + + x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim) + + fn_expected = pytensor.function( + [A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve") + ) + assert count_kron_ops(fn_expected.maker.fgraph) == 1 + + fn = pytensor.function([A, B, y], x) + assert count_kron_ops(fn.maker.fgraph) == 0 + + rng = np.random.default_rng(sum(map(ord, "Go away Kron!"))) + a_val = rng.normal(size=a_shape) + b_val = rng.normal(size=b_shape) + y_val = rng.normal(size=y_shape) + + if solve_kwargs["assume_a"] == "pos": + a_val = a_val @ np.moveaxis(a_val, -2, -1) + b_val = b_val @ np.moveaxis(b_val, -2, -1) + elif solve_kwargs["assume_a"] == "upper triangular": + a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1) + b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1) + + if len(a_shape) > 2: + a_idx = (slice(None, None), *a_idx) + if len(b_shape) > 2: + b_idx = (slice(None, None), *b_idx) + + a_val[a_idx] = 0 + b_val[b_idx] = 0 + + a_val = a_val.astype(config.floatX) + b_val = b_val.astype(config.floatX) + y_val = y_val.astype(config.floatX) + + expected = fn_expected(a_val, b_val, y_val) + result = fn(a_val, b_val, y_val) + + if config.floatX == "float64": + tol = 1e-8 + elif config.floatX == "float32" and not solve_kwargs["assume_a"] == "pos": + tol = 1e-4 + else: + # Precision needs to be extremely low for the assume_a = pos test to pass in float32 mode. I don't have a + # good theory of why. Skipping this case would also be an option. + tol = 1e-2 + + np.testing.assert_allclose( + expected, + result, + atol=tol, + rtol=tol, + ) + + +def test_rewrite_solve_kron_to_solve_not_applied(): + # Check that the rewrite is not applied when the component matrices to the kron are static and not square + A = pt.tensor("A", shape=(3, 2)) + B = pt.tensor("B", shape=(2, 3)) + C = pt.linalg.kron(A, B) + + y = pt.vector("y", shape=(6,)) + x = pt.linalg.solve(C, y) + + fn = pytensor.function([A, B, y], x) + + assert count_kron_ops(fn.maker.fgraph) == 1 + + # If shapes are static, it should always be applied + A = pt.tensor("A", shape=(3, None, None)) + B = pt.tensor("B", shape=(3, None, None)) + C = pt.linalg.kron(A, B) + y = pt.tensor("y", shape=(None,)) + x = pt.linalg.solve(C, y) + fn = pytensor.function([A, B, y], x) + + assert count_kron_ops(fn.maker.fgraph) == 0 + + +@pytest.mark.parametrize( + "a_shape, b_shape", + [((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))], + ids=["small", "medium", "large"], +) +@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"]) +def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark): + A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape) + C = pt.linalg.kron(A, B) + + m, n = a_shape[-2], b_shape[-2] + has_batch = len(a_shape) == 3 + y_shape = (a_shape[0], m * n) if has_batch else (m * n,) + y = pt.tensor("y", shape=y_shape) + x = pt.linalg.solve(C, y, b_ndim=1) + + rng = np.random.default_rng(sum(map(ord, "Go away Kron!"))) + a_val = rng.normal(size=a_shape).astype(config.floatX) + b_val = rng.normal(size=b_shape).astype(config.floatX) + y_val = rng.normal(size=y_shape).astype(config.floatX) + + mode = ( + get_default_mode() + if rewrite + else get_default_mode().excluding("rewrite_solve_kron_to_solve") + ) + + fn = pytensor.function([A, B, y], x, mode=mode) + benchmark(fn, a_val, b_val, y_val) + + def test_cholesky_eye_rewrite(): x = pt.eye(10) L = pt.linalg.cholesky(x)