Skip to content

Commit f3b628d

Browse files
Rewrite scalar solve to division
1 parent ff98ab8 commit f3b628d

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,3 +1020,29 @@ def slogdet_specialization(fgraph, node):
10201020
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
10211021
}
10221022
return replacements
1023+
1024+
1025+
@register_stabilize
1026+
@register_canonicalize
1027+
@node_rewriter([Blockwise])
1028+
def scalar_solve_to_divison(fgraph, node):
1029+
"""
1030+
Replace solve(a, b) with b / a if a is a (1, 1) matrix
1031+
"""
1032+
1033+
core_op = node.op.core_op
1034+
if not isinstance(core_op, Solve):
1035+
return None
1036+
1037+
a, b = node.inputs
1038+
old_out = node.outputs[0]
1039+
if not all(a.broadcastable[-2:]):
1040+
return None
1041+
1042+
new_out = b / a
1043+
if core_op.b_ndim == 1:
1044+
new_out = new_out.squeeze(-1)
1045+
1046+
copy_stack_trace(old_out, new_out)
1047+
1048+
return [new_out]

tests/tensor/rewriting/test_linalg.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,3 +993,25 @@ def test_slogdet_specialization():
993993
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
994994
nodes = f.maker.fgraph.apply_nodes
995995
assert not any(isinstance(node.op, SLogDet) for node in nodes)
996+
997+
998+
def test_scalar_solve_to_division_rewrite():
999+
rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
1000+
1001+
a = pt.dmatrix("a", shape=(1, 1))
1002+
b = pt.dvector("b")
1003+
1004+
c = pt.linalg.solve(a, b, b_ndim=1)
1005+
1006+
f = function([a, b], c, mode="FAST_RUN")
1007+
nodes = f.maker.fgraph.apply_nodes
1008+
1009+
assert not any(isinstance(node.op, Solve) for node in nodes)
1010+
1011+
a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
1012+
b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
1013+
1014+
c_val = np.linalg.solve(a_val, b_val)
1015+
np.testing.assert_allclose(
1016+
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
1017+
)

0 commit comments

Comments
 (0)