Skip to content

Commit 1527af2

Browse files
committed
Use python implementation for constant_folding Ops
1 parent 55f5abb commit 1527af2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,8 +1106,15 @@ def unconditional_constant_folding(fgraph, node):
11061106
storage_map[o] = [None]
11071107
compute_map[o] = [False]
11081108

1109-
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
1110-
required = thunk()
1109+
try:
1110+
thunk = node.op.make_thunk(
1111+
node, storage_map, compute_map, no_recycling=[], impl="py"
1112+
)
1113+
required = thunk()
1114+
except NotImplementedError:
1115+
# Not all Ops have a python implementation
1116+
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
1117+
required = thunk()
11111118

11121119
# A node whose inputs are all provided should always return successfully
11131120
assert not required

0 commit comments

Comments
 (0)