Skip to content

Commit 79655b5

Browse files
committed
Added dot kron rewrite
1 parent 33a4d48 commit 79655b5

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
register_specialize,
4343
register_stabilize,
4444
)
45+
from pytensor.tensor.shape import Reshape
4546
from pytensor.tensor.slinalg import (
4647
BlockDiagonal,
4748
Cholesky,
@@ -989,3 +990,23 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
989990
"jax",
990991
position=0.9, # Run before canonicalization
991992
)
993+
994+
995+
@register_canonicalize
996+
@register_stabilize
997+
@node_rewriter([Reshape])
998+
def rewrite_dot_kron(fgraph, node):
999+
if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot)):
1000+
return False
1001+
1002+
potential_kron = node.inputs[0].owner
1003+
if not (isinstance(potential_kron.op, KroneckerProduct)):
1004+
return False
1005+
1006+
c = node.inputs[1]
1007+
[a, b] = potential_kron.inputs
1008+
1009+
m, n = a.type.shape
1010+
p, q = b.type.shape
1011+
out_clever = (b @ c.reshape(shape=(n, q)).T @ a.T).ravel()
1012+
return [out_clever]

tests/tensor/rewriting/test_linalg.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,3 +906,29 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
906906
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
907907
nodes = f_rewritten.maker.fgraph.apply_nodes
908908
assert any(isinstance(node.op, Cholesky) for node in nodes)
909+
910+
911+
def test_dot_kron_rewrite():
912+
m, n, p, q = 3, 4, 6, 7
913+
a = pt.matrix("a", shape=(m, n))
914+
b = pt.matrix("b", shape=(p, q))
915+
c = pt.matrix("c", shape=(n * q, 1))
916+
out_direct = pt.linalg.kron(a, b) @ c
917+
918+
# REWRITE TEST
919+
f_direct_rewritten = function([a, b, c], out_direct)
920+
# nodes = f_direct_rewritten.maker.fgraph.apply_nodes
921+
# Add assertion test here
922+
923+
# NUMERIC VALUE TEST
924+
a_test = np.random.rand(m, n)
925+
b_test = np.random.rand(p, q)
926+
c_test = np.random.rand(n * q, 1)
927+
out_direct_val = np.kron(a_test, b_test) @ c_test
928+
out_clever_val = f_direct_rewritten(a_test, b_test, c_test)
929+
assert_allclose(
930+
out_direct_val,
931+
out_clever_val,
932+
atol=1e-3 if config.floatX == "float32" else 1e-8,
933+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
934+
)

0 commit comments

Comments
 (0)