@@ -1207,25 +1207,23 @@ def local_merge_alloc(fgraph, node):
1207
1207
inputs_inner = node .inputs [0 ].owner .inputs
1208
1208
dims_outer = inputs_outer [1 :]
1209
1209
dims_inner = inputs_inner [1 :]
1210
- dims_outer_rev = dims_outer [::- 1 ]
1211
- dims_inner_rev = dims_inner [::- 1 ]
1210
+ assert len (dims_inner ) <= len (dims_outer )
1212
1211
# check if the pattern of broadcasting is matched, in the reversed ordering.
1213
1212
# The reverse ordering is needed when an Alloc add an implicit new
1214
1213
# broadcasted dimensions to its inputs[0]. Eg:
1215
1214
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
1216
- i = 0
1217
- for dim_inner , dim_outer in zip (dims_inner_rev , dims_outer_rev , strict = False ):
1218
- if dim_inner != dim_outer :
1219
- if isinstance (dim_inner , Constant ) and dim_inner .data == 1 :
1220
- pass
1221
- else :
1222
- dims_outer [- 1 - i ] = Assert (
1223
- "You have a shape error in your graph. To see a better"
1224
- " error message and a stack trace of where in your code"
1225
- " the error is created, use the PyTensor flags"
1226
- " optimizer=None or optimizer=fast_compile."
1227
- )(dim_outer , eq (dim_outer , dim_inner ))
1228
- i += 1
1215
+ for i , dim_inner in enumerate (reversed (dims_inner )):
1216
+ dim_outer = dims_outer [- 1 - i ]
1217
+ if dim_inner == dim_outer :
1218
+ continue
1219
+ if isinstance (dim_inner , Constant ) and dim_inner .data == 1 :
1220
+ continue
1221
+ dims_outer [- 1 - i ] = Assert (
1222
+ "You have a shape error in your graph. To see a better"
1223
+ " error message and a stack trace of where in your code"
1224
+ " the error is created, use the PyTensor flags"
1225
+ " optimizer=None or optimizer=fast_compile."
1226
+ )(dim_outer , eq (dim_outer , dim_inner ))
1229
1227
return [alloc (inputs_inner [0 ], * dims_outer )]
1230
1228
1231
1229
0 commit comments