Skip to content

Commit 13b20e2

Browse files
committed
Added rewrite for diag of kronecker product
1 parent 1a1c62b commit 13b20e2

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
from pytensor.scalar.basic import Mul
1313
from pytensor.tensor.basic import (
1414
AllocDiag,
15+
ExtractDiag,
1516
Eye,
1617
TensorVariable,
18+
diag,
1719
diagonal,
1820
)
1921
from pytensor.tensor.blas import Dot22
2022
from pytensor.tensor.blockwise import Blockwise
2123
from pytensor.tensor.elemwise import DimShuffle, Elemwise
22-
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
24+
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
2325
from pytensor.tensor.nlinalg import (
2426
SVD,
2527
KroneckerProduct,
@@ -701,3 +703,20 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
701703
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
702704

703705
return [eye_input / non_eye_input]
706+
707+
708+
@register_canonicalize
709+
@register_stabilize
710+
@node_rewriter([ExtractDiag])
711+
def rewrite_diag_kronecker(fgraph, node):
712+
# Check for inner kron operation
713+
potential_kron = node.inputs[0].owner
714+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
715+
return None
716+
717+
# Find the matrices
718+
a, b = potential_kron.inputs
719+
diag_a, diag_b = diag(a), diag(b)
720+
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
721+
722+
return [outer_prod_as_vector]

tests/tensor/rewriting/test_linalg.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,26 @@ def test_inv_diag_from_diag(inv_op):
662662
atol=ATOL,
663663
rtol=RTOL,
664664
)
665+
666+
667+
def test_diag_kronecker_rewrite():
668+
a, b = pt.dmatrices("a", "b")
669+
kron_prod = pt.linalg.kron(a, b)
670+
diag_kron_prod = pt.diag(kron_prod)
671+
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")
672+
673+
# Rewrite Test
674+
nodes = f_rewritten.maker.fgraph.apply_nodes
675+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
676+
677+
# Value Test
678+
a_test, b_test = np.random.rand(2, 20, 20)
679+
kron_prod_test = np.kron(a_test, b_test)
680+
diag_kron_prod_test = np.diag(kron_prod_test)
681+
rewritten_val = f_rewritten(a_test, b_test)
682+
assert_allclose(
683+
diag_kron_prod_test,
684+
rewritten_val,
685+
atol=1e-3 if config.floatX == "float32" else 1e-8,
686+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
687+
)

0 commit comments

Comments
 (0)