Skip to content

Commit 5e64f0a

Browse files
committed
Use python implementation for constant_folding Ops
1 parent 0f65bb7 commit 5e64f0a

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pytensor.compile.ops import ViewOp
3232
from pytensor.graph import FunctionGraph
3333
from pytensor.graph.basic import Constant
34+
from pytensor.graph.op import _NoPythonOp
3435
from pytensor.graph.rewriting.basic import (
3536
NodeProcessingGraphRewriter,
3637
NodeRewriter,
@@ -1108,7 +1109,12 @@ def unconditional_constant_folding(fgraph, node):
11081109
storage_map[o] = [None]
11091110
compute_map[o] = [False]
11101111

1111-
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
1112+
if isinstance(node.op, _NoPythonOp):
1113+
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
1114+
else:
1115+
thunk = node.op.make_thunk(
1116+
node, storage_map, compute_map, no_recycling=[], impl="py"
1117+
)
11121118
required = thunk()
11131119

11141120
# A node whose inputs are all provided should always return successfully

0 commit comments

Comments
 (0)