Skip to content

Commit b9e35cf

Browse files
Add rewrite to move transpose operations to BLAS
1 parent 2d75e41 commit b9e35cf

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,8 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
957957
@node_rewriter([det])
958958
def slogdet_specialization(fgraph, node):
959959
"""
960-
This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation.
960+
This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites
961+
them using the SLogDet operation.
961962
962963
Parameters
963964
----------
@@ -1013,3 +1014,30 @@ def slogdet_specialization(fgraph, node):
10131014
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
10141015
}
10151016
return replacements
1017+
1018+
1019+
@register_specialize
1020+
@node_rewriter([Blockwise])
1021+
def rewrite_A_transposed_solve_to_transposed_solver(fgraph, node):
1022+
"""
1023+
Replace solve(A.T, b) with solve(A, b, transposed=True).
1024+
"""
1025+
solve_op = node.op.core_op
1026+
if not isinstance(solve_op, Solve):
1027+
return None
1028+
1029+
A, b = node.inputs
1030+
1031+
if not is_matrix_transpose(A):
1032+
return None
1033+
1034+
# If we've gotten here, A is actually A.T
1035+
A = A.owner.inputs[0]
1036+
solve_kwargs = solve_op._props_dict()
1037+
1038+
# The reason for using `not transposed` here is that we could have had solve(A.T, b, transposed=True), in which
1039+
# case we can just do solve(A, b, transposed=False) and the rewrite is still valid. In the "base case" we had
1040+
# solve(A.T, b, transposed=False) and we're going to solve(A, b, tranposed=True).
1041+
solve_kwargs["transposed"] = not solve_kwargs["transposed"]
1042+
1043+
return [Blockwise(Solve(**solve_kwargs))(A, b)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.graph import FunctionGraph
1314
from pytensor.graph.rewriting.utils import rewrite_graph
1415
from pytensor.tensor import swapaxes
1516
from pytensor.tensor.blockwise import Blockwise
@@ -993,3 +994,23 @@ def test_slogdet_specialization():
993994
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
994995
nodes = f.maker.fgraph.apply_nodes
995996
assert not any(isinstance(node.op, SLogDet) for node in nodes)
997+
998+
999+
def test_rewrite_A_transposed_solve_to_transposed_solver():
1000+
A = matrix("A")
1001+
b = vector("b")
1002+
x = pt.linalg.solve(A.T, b)
1003+
1004+
fg = FunctionGraph([A, b], [x])
1005+
assert any(isinstance(node.op, DimShuffle) for node in fg.toposort())
1006+
1007+
f = function([A, b], x, mode="FAST_RUN")
1008+
assert not any(
1009+
isinstance(node.op, DimShuffle) for node in f.maker.fgraph.toposort()
1010+
)
1011+
1012+
A_val = np.random.normal(size=(10, 10))
1013+
b_val = np.random.normal(size=(10,))
1014+
1015+
g = function([A, b], pt.linalg.solve(A.T, b), mode="FAST_COMPILE")
1016+
np.testing.assert_allclose(f(A_val, b_val), g(A_val, b_val))

0 commit comments

Comments
 (0)