Skip to content

Commit 0a4879e

Browse files
committed
changing tracked node to dot
1 parent 79655b5 commit 0a4879e

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
register_specialize,
4343
register_stabilize,
4444
)
45-
from pytensor.tensor.shape import Reshape
4645
from pytensor.tensor.slinalg import (
4746
BlockDiagonal,
4847
Cholesky,
@@ -994,11 +993,8 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
994993

995994
@register_canonicalize
996995
@register_stabilize
997-
@node_rewriter([Reshape])
996+
@node_rewriter([Dot])
998997
def rewrite_dot_kron(fgraph, node):
999-
if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot)):
1000-
return False
1001-
1002998
potential_kron = node.inputs[0].owner
1003999
if not (isinstance(potential_kron.op, KroneckerProduct)):
10041000
return False

tests/tensor/rewriting/test_linalg.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.tensor import swapaxes
1515
from pytensor.tensor.blockwise import Blockwise
1616
from pytensor.tensor.elemwise import DimShuffle
17-
from pytensor.tensor.math import _allclose, dot, matmul
17+
from pytensor.tensor.math import Dot, _allclose, dot, matmul
1818
from pytensor.tensor.nlinalg import (
1919
SVD,
2020
Det,
@@ -916,9 +916,10 @@ def test_dot_kron_rewrite():
916916
out_direct = pt.linalg.kron(a, b) @ c
917917

918918
# REWRITE TEST
919-
f_direct_rewritten = function([a, b, c], out_direct)
920-
# nodes = f_direct_rewritten.maker.fgraph.apply_nodes
921-
# Add assertion test here
919+
f_direct_rewritten = function([a, b, c], out_direct, mode="FAST_RUN")
920+
nodes = f_direct_rewritten.maker.fgraph.apply_nodes
921+
print(nodes)
922+
assert not any(isinstance(node.op.core_op, Dot) for node in nodes)
922923

923924
# NUMERIC VALUE TEST
924925
a_test = np.random.rand(m, n)

0 commit comments

Comments
 (0)