diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index ecdbe6e7ed..2a1a71ae40 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -1046,11 +1046,15 @@ def scalar_solve_to_division(fgraph, node):
     if not all(a.broadcastable[-2:]):
         return None
 
+    if core_op.b_ndim == 1:
+        # Convert b to a column matrix
+        b = b[..., None]
+
     # Special handling for different types of solve
     match core_op:
         case SolveTriangular():
             # Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
-            new_out = b / a if not core_op.unit_diagonal else b
+            new_out = b / a if not core_op.unit_diagonal else pt.second(a, b)
         case CholeskySolve():
             new_out = b / a**2
         case Solve():
@@ -1061,6 +1065,7 @@ def scalar_solve_to_division(fgraph, node):
             )
 
     if core_op.b_ndim == 1:
+        # Squeeze away the column dimension added earlier
         new_out = new_out.squeeze(-1)
 
     copy_stack_trace(old_out, new_out)
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 539951c1d6..38f7369bcc 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -10,6 +10,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.graph import ancestors
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
@@ -989,34 +990,73 @@ def test_slogdet_specialization():
 
 
 @pytest.mark.parametrize(
-    "Op, fn",
+    "a_batch_shape", [(), (5,)], ids=lambda x: f"a_batch_shape={x}"
+)
+@pytest.mark.parametrize(
+    "b_batch_shape", [(), (5,)], ids=lambda x: f"b_batch_shape={x}"
+)
+@pytest.mark.parametrize("b_ndim", (1, 2), ids=lambda x: f"b_ndim={x}")
+@pytest.mark.parametrize(
+    "op, fn, extra_kwargs",
     [
-        (Solve, pt.linalg.solve),
-        (SolveTriangular, pt.linalg.solve_triangular),
-        (CholeskySolve, pt.linalg.cho_solve),
+        (Solve, pt.linalg.solve, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {"unit_diagonal": True}),
+        (CholeskySolve, pt.linalg.cho_solve, {}),
     ],
 )
-def test_scalar_solve_to_division_rewrite(Op, fn):
-    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
+def test_scalar_solve_to_division_rewrite(
+    op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
+):
+    def solve_op_in_graph(graph):
+        return any(
+            isinstance(var.owner.op, SolveBase)
+            or (
+                isinstance(var.owner.op, Blockwise)
+                and isinstance(var.owner.op.core_op, SolveBase)
+            )
+            for var in ancestors(graph)
+            if var.owner
+        )
+
+    rng = np.random.default_rng(
+        [
+            sum(map(ord, "scalar_solve_to_division_rewrite")),
+            b_ndim,
+            *a_batch_shape,
+            1,
+            *b_batch_shape,
+        ]
+    )
 
-    a = pt.dmatrix("a", shape=(1, 1))
-    b = pt.dvector("b")
+    a = pt.tensor("a", shape=(*a_batch_shape, 1, 1), dtype="float64")
+    b = pt.tensor("b", shape=(*b_batch_shape, *([None] * b_ndim)), dtype="float64")
 
-    if Op is CholeskySolve:
+    if op is CholeskySolve:
         # cho_solve expects a tuple (c, lower) as the first input
-        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
+        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=b_ndim, **extra_kwargs)
     else:
-        c = fn(a, b, b_ndim=1)
+        c = fn(a, b, b_ndim=b_ndim, **extra_kwargs)
 
+    assert solve_op_in_graph([c])
     f = function([a, b], c, mode="FAST_RUN")
-    nodes = f.maker.fgraph.apply_nodes
+    assert not solve_op_in_graph(f.maker.fgraph.outputs)
+
+    a_val = rng.normal(size=(*a_batch_shape, 1, 1)).astype(pytensor.config.floatX)
+    b_core_shape = (1, 5) if b_ndim == 2 else (1,)
+    b_val = rng.normal(size=(*b_batch_shape, *b_core_shape)).astype(
+        pytensor.config.floatX
+    )
 
-    assert not any(isinstance(node.op, Op) for node in nodes)
+    if op is CholeskySolve:
+        # Avoid sign ambiguity in solve
+        a_val = a_val**2
 
-    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
-    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
+    if extra_kwargs.get("unit_diagonal", False):
+        a_val = np.ones_like(a_val)
 
-    c_val = np.linalg.solve(a_val, b_val)
+    signature = "(n,m),(m)->(n)" if b_ndim == 1 else "(n,m),(m,k)->(n,k)"
+    c_val = np.vectorize(np.linalg.solve, signature=signature)(a_val, b_val)
     np.testing.assert_allclose(
         f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
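
For context, the identity the `scalar_solve_to_division` rewrite relies on can be checked directly with NumPy. The sketch below is illustrative only (the array names, seed, and shapes are chosen here, not taken from the PR); it mirrors why the new `b[..., None]` promotion and the trailing `squeeze(-1)` are needed when `b_ndim == 1` and `a` is batched:

    import numpy as np

    rng = np.random.default_rng(0)

    # Batched 1x1 systems: a has shape (batch, 1, 1); with b_ndim == 1,
    # b has shape (batch, 1).
    a = rng.normal(size=(5, 1, 1))
    b = rng.normal(size=(5, 1))

    # Reference result: solve each 1x1 system individually.
    expected = np.vectorize(np.linalg.solve, signature="(n,m),(m)->(n)")(a, b)

    # Dividing b by a directly would broadcast (5, 1) against (5, 1, 1)
    # to shape (5, 5, 1). Promoting b to a column matrix first keeps the
    # shapes aligned, and squeezing afterwards restores the vector output.
    result = (b[..., None] / a).squeeze(-1)

    np.testing.assert_allclose(result, expected)

The `pt.second(a, b)` change addresses the same shape concern in the `unit_diagonal` corner case: rather than returning `b` unchanged, the result is broadcast against `a`, so a batched `a` still yields an output of the expected shape.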