Skip to content

Commit dfdaeab

Browse files
committed
Rewrite local_merge_alloc to remove a non-strict zip
1 parent 79452d8 commit dfdaeab

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,25 +1196,23 @@ def local_merge_alloc(fgraph, node):
11961196
inputs_inner = node.inputs[0].owner.inputs
11971197
dims_outer = inputs_outer[1:]
11981198
dims_inner = inputs_inner[1:]
1199-
dims_outer_rev = dims_outer[::-1]
1200-
dims_inner_rev = dims_inner[::-1]
1199+
assert len(dims_inner) <= len(dims_outer)
12011200
# check if the pattern of broadcasting is matched, in the reversed ordering.
12021201
# The reverse ordering is needed when an Alloc add an implicit new
12031202
# broadcasted dimensions to its inputs[0]. Eg:
12041203
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
1205-
i = 0
1206-
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False):
1207-
if dim_inner != dim_outer:
1208-
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
1209-
pass
1210-
else:
1211-
dims_outer[-1 - i] = Assert(
1212-
"You have a shape error in your graph. To see a better"
1213-
" error message and a stack trace of where in your code"
1214-
" the error is created, use the PyTensor flags"
1215-
" optimizer=None or optimizer=fast_compile."
1216-
)(dim_outer, eq(dim_outer, dim_inner))
1217-
i += 1
1204+
for i, dim_inner in enumerate(reversed(dims_inner)):
1205+
dim_outer = dims_outer[-1 - i]
1206+
if dim_inner == dim_outer:
1207+
continue
1208+
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
1209+
continue
1210+
dims_outer[-1 - i] = Assert(
1211+
"You have a shape error in your graph. To see a better"
1212+
" error message and a stack trace of where in your code"
1213+
" the error is created, use the PyTensor flags"
1214+
" optimizer=None or optimizer=fast_compile."
1215+
)(dim_outer, eq(dim_outer, dim_inner))
12181216
return [alloc(inputs_inner[0], *dims_outer)]
12191217

12201218

0 commit comments

Comments
 (0)