Skip to content

Commit 344a43c

Browse files
committed
Rewrite local_merge_alloc to remove a non-strict zip
1 parent b2ebdd2 commit 344a43c

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,25 +1207,23 @@ def local_merge_alloc(fgraph, node):
12071207
inputs_inner = node.inputs[0].owner.inputs
12081208
dims_outer = inputs_outer[1:]
12091209
dims_inner = inputs_inner[1:]
1210-
dims_outer_rev = dims_outer[::-1]
1211-
dims_inner_rev = dims_inner[::-1]
1210+
assert len(dims_inner) <= len(dims_outer)
12121211
# check if the pattern of broadcasting is matched, in the reversed ordering.
12131212
# The reverse ordering is needed when an Alloc add an implicit new
12141213
# broadcasted dimensions to its inputs[0]. Eg:
12151214
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
1216-
i = 0
1217-
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False):
1218-
if dim_inner != dim_outer:
1219-
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
1220-
pass
1221-
else:
1222-
dims_outer[-1 - i] = Assert(
1223-
"You have a shape error in your graph. To see a better"
1224-
" error message and a stack trace of where in your code"
1225-
" the error is created, use the PyTensor flags"
1226-
" optimizer=None or optimizer=fast_compile."
1227-
)(dim_outer, eq(dim_outer, dim_inner))
1228-
i += 1
1215+
for i, dim_inner in enumerate(reversed(dims_inner)):
1216+
dim_outer = dims_outer[-1 - i]
1217+
if dim_inner == dim_outer:
1218+
continue
1219+
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
1220+
continue
1221+
dims_outer[-1 - i] = Assert(
1222+
"You have a shape error in your graph. To see a better"
1223+
" error message and a stack trace of where in your code"
1224+
" the error is created, use the PyTensor flags"
1225+
" optimizer=None or optimizer=fast_compile."
1226+
)(dim_outer, eq(dim_outer, dim_inner))
12291227
return [alloc(inputs_inner[0], *dims_outer)]
12301228

12311229

0 commit comments

Comments
 (0)