pytensor/tensor/rewriting/linalg.py: 6 additions & 1 deletion
@@ -1046,11 +1046,15 @@ def scalar_solve_to_division(fgraph, node):
     if not all(a.broadcastable[-2:]):
         return None
 
+    if core_op.b_ndim == 1:
+        # Convert b to a column matrix
+        b = b[..., None]
Review comment by @ricardoV94 (Member, Author), Jul 2, 2025, on lines +1049 to +1051:

If a had shape (1, 1, 1) and b had shape (5, 1) with b_ndim=1 (meaning 5 is a batch dimension of b), we were returning an output with the wrong shape: after broadcasting with a it became (1, 5, 1), and then (1, 5) after the squeeze below. See the numpy sketch after this diff.


+
     # Special handling for different types of solve
     match core_op:
         case SolveTriangular():
             # Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
-            new_out = b / a if not core_op.unit_diagonal else b
+            new_out = b / a if not core_op.unit_diagonal else pt.second(a, b)
         case CholeskySolve():
             new_out = b / a**2
         case Solve():
@@ -1061,6 +1065,7 @@ def scalar_solve_to_division(fgraph, node):
             )
 
     if core_op.b_ndim == 1:
+        # Squeeze away the column dimension added earlier
         new_out = new_out.squeeze(-1)
 
     copy_stack_trace(old_out, new_out)
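The shape bug described in the review comment above is easiest to see with plain numpy broadcasting. Below is a minimal numpy-only sketch (illustrative, not PyTensor code) of the old and new behaviour for b_ndim=1. The pt.second(a, b) change in the unit_diagonal branch has the same motivation: pt.second broadcasts b against a, so that branch returns the shape b / a would have rather than b's own shape.

import numpy as np

# Shapes from the review comment: a has shape (1, 1, 1); b has shape (5, 1)
# with b_ndim=1, so 5 is a batch dimension of b and the core vector has length 1.
a = np.full((1, 1, 1), 2.0)
b = np.arange(5.0).reshape(5, 1)

# Old behaviour: divide directly, then squeeze the last axis.
# (5, 1) / (1, 1, 1) broadcasts to (1, 5, 1); squeeze(-1) gives (1, 5): wrong.
old = (b / a).squeeze(-1)

# New behaviour: turn b into a column matrix first, then divide and squeeze.
# (5, 1, 1) / (1, 1, 1) broadcasts to (5, 1, 1); squeeze(-1) gives (5, 1): right.
new = (b[..., None] / a).squeeze(-1)

print(old.shape, new.shape)  # (1, 5) (5, 1)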
tests/tensor/rewriting/test_linalg.py: 56 additions & 16 deletions
@@ -10,6 +10,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.graph import ancestors
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
@@ -989,34 +990,73 @@ def test_slogdet_specialization():


 @pytest.mark.parametrize(
-    "Op, fn",
+    "a_batch_shape", [(), (5,)], ids=lambda x: f"a_batch_shape={x}"
+)
+@pytest.mark.parametrize(
+    "b_batch_shape", [(), (5,)], ids=lambda x: f"b_batch_shape={x}"
+)
+@pytest.mark.parametrize("b_ndim", (1, 2), ids=lambda x: f"b_ndim={x}")
+@pytest.mark.parametrize(
+    "op, fn, extra_kwargs",
     [
-        (Solve, pt.linalg.solve),
-        (SolveTriangular, pt.linalg.solve_triangular),
-        (CholeskySolve, pt.linalg.cho_solve),
+        (Solve, pt.linalg.solve, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {"unit_diagonal": True}),
+        (CholeskySolve, pt.linalg.cho_solve, {}),
     ],
 )
-def test_scalar_solve_to_division_rewrite(Op, fn):
-    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
+def test_scalar_solve_to_division_rewrite(
+    op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
+):
+    def solve_op_in_graph(graph):
+        return any(
+            isinstance(var.owner.op, SolveBase)
+            or (
+                isinstance(var.owner.op, Blockwise)
+                and isinstance(var.owner.op.core_op, SolveBase)
+            )
+            for var in ancestors(graph)
+            if var.owner
+        )
+
+    rng = np.random.default_rng(
+        [
+            sum(map(ord, "scalar_solve_to_division_rewrite")),
+            b_ndim,
+            *a_batch_shape,
+            1,
+            *b_batch_shape,
+        ]
+    )
 
-    a = pt.dmatrix("a", shape=(1, 1))
-    b = pt.dvector("b")
+    a = pt.tensor("a", shape=(*a_batch_shape, 1, 1), dtype="float64")
+    b = pt.tensor("b", shape=(*b_batch_shape, *([None] * b_ndim)), dtype="float64")
 
-    if Op is CholeskySolve:
+    if op is CholeskySolve:
         # cho_solve expects a tuple (c, lower) as the first input
-        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
+        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=b_ndim, **extra_kwargs)
     else:
-        c = fn(a, b, b_ndim=1)
+        c = fn(a, b, b_ndim=b_ndim, **extra_kwargs)
 
+    assert solve_op_in_graph([c])
     f = function([a, b], c, mode="FAST_RUN")
-    nodes = f.maker.fgraph.apply_nodes
+    assert not solve_op_in_graph(f.maker.fgraph.outputs)
 
-    assert not any(isinstance(node.op, Op) for node in nodes)
+    a_val = rng.normal(size=(*a_batch_shape, 1, 1)).astype(pytensor.config.floatX)
+    b_core_shape = (1, 5) if b_ndim == 2 else (1,)
+    b_val = rng.normal(size=(*b_batch_shape, *b_core_shape)).astype(
+        pytensor.config.floatX
+    )
 
-    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
-    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
+    if op is CholeskySolve:
+        # Avoid sign ambiguity in solve
+        a_val = a_val**2
+
+    if extra_kwargs.get("unit_diagonal", False):
+        a_val = np.ones_like(a_val)
 
-    c_val = np.linalg.solve(a_val, b_val)
+    signature = "(n,m),(m)->(n)" if b_ndim == 1 else "(n,m),(m,k)->(n,k)"
+    c_val = np.vectorize(np.linalg.solve, signature=signature)(a_val, b_val)
     np.testing.assert_allclose(
         f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )
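For the batched ground truth, the updated test relies on np.vectorize with a gufunc signature, which loops np.linalg.solve over the leading batch dimensions while the core dimensions follow the signature string. A minimal numpy-only sketch with the b_ndim=1 shapes from the review comment above (the seed here is arbitrary):

import numpy as np

rng = np.random.default_rng(0)

# Core signature for b_ndim=1: a is (n, m), b is (m), result is (n).
batched_solve = np.vectorize(np.linalg.solve, signature="(n,m),(m)->(n)")

a_val = rng.normal(size=(1, 1, 1))  # batch (1,), core (1, 1)
b_val = rng.normal(size=(5, 1))     # batch (5,), core (1,)

# Batch dims (1,) and (5,) broadcast to (5,), so the result has shape (5, 1).
print(batched_solve(a_val, b_val).shape)  # (5, 1)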