@@ -1196,25 +1196,23 @@ def local_merge_alloc(fgraph, node):
     inputs_inner = node.inputs[0].owner.inputs
     dims_outer = inputs_outer[1:]
     dims_inner = inputs_inner[1:]
-    dims_outer_rev = dims_outer[::-1]
-    dims_inner_rev = dims_inner[::-1]
+    assert len(dims_inner) <= len(dims_outer)
     # check if the pattern of broadcasting is matched, in the reversed ordering.
     # The reverse ordering is needed when an Alloc add an implicit new
     # broadcasted dimensions to its inputs[0]. Eg:
     # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
-    i = 0
-    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False):
-        if dim_inner != dim_outer:
-            if isinstance(dim_inner, Constant) and dim_inner.data == 1:
-                pass
-            else:
-                dims_outer[-1 - i] = Assert(
-                    "You have a shape error in your graph. To see a better"
-                    " error message and a stack trace of where in your code"
-                    " the error is created, use the PyTensor flags"
-                    " optimizer=None or optimizer=fast_compile."
-                )(dim_outer, eq(dim_outer, dim_inner))
-        i += 1
+    for i, dim_inner in enumerate(reversed(dims_inner)):
+        dim_outer = dims_outer[-1 - i]
+        if dim_inner == dim_outer:
+            continue
+        if isinstance(dim_inner, Constant) and dim_inner.data == 1:
+            continue
+        dims_outer[-1 - i] = Assert(
+            "You have a shape error in your graph. To see a better"
+            " error message and a stack trace of where in your code"
+            " the error is created, use the PyTensor flags"
+            " optimizer=None or optimizer=fast_compile."
+        )(dim_outer, eq(dim_outer, dim_inner))
     return [alloc(inputs_inner[0], *dims_outer)]
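
Read on its own, the refactored loop walks the inner and outer shapes right-aligned, skips dimensions that already agree or that the inner Alloc broadcasts (a Constant with value 1), and wraps any remaining outer dimension in an Assert so a real mismatch still fails at run time rather than being silently merged. The standalone sketch below mirrors just that matching step with plain ints standing in for symbolic dimension variables; the helper name match_broadcast_dims and its "return the mismatched indices" interface are illustrative assumptions, not part of the PyTensor rewrite.

def match_broadcast_dims(dims_inner, dims_outer):
    """Indices into dims_outer whose size cannot be reconciled statically."""
    assert len(dims_inner) <= len(dims_outer)
    mismatched = []
    # Compare from the right so the extra leading dimensions that the outer
    # Alloc introduces (e.g. `x` in Alloc(Alloc(m, y, 1, 1), x, y, z, w))
    # are simply kept as-is.
    for i, dim_inner in enumerate(reversed(dims_inner)):
        dim_outer = dims_outer[-1 - i]
        if dim_inner == dim_outer:
            continue
        if dim_inner == 1:  # broadcastable inner dim: the outer size wins
            continue
        mismatched.append(len(dims_outer) - 1 - i)
    return mismatched

# Alloc(Alloc(m, y, 1, 1), x, y, z, w): the trailing (y, 1, 1) lines up with
# (y, z, w); both 1s broadcast, y matches y, and x is a new leading dimension.
assert match_broadcast_dims((5, 1, 1), (2, 5, 3, 4)) == []
# Inner size 7 against outer size 3 cannot be merged statically; the real
# rewrite wraps that outer dimension in an Assert node instead of failing here.
assert match_broadcast_dims((7, 1), (2, 5, 3, 4)) == [2]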
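To see the rewrite from the user side, one can build the nested-Alloc pattern from the comment in the hunk and inspect the compiled graph. The snippet below is a hedged usage sketch: it assumes the default compilation mode applies local_merge_alloc, and it only relies on public PyTensor entry points (pytensor.function, pytensor.tensor.alloc, FunctionGraph.toposort).

import pytensor
import pytensor.tensor as pt

m = pt.scalar("m")
x, y, z, w = (pt.iscalar(name) for name in "xyzw")

# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
out = pt.alloc(pt.alloc(m, y, 1, 1), x, y, z, w)
f = pytensor.function([m, x, y, z, w], out)

# After the rewrite, the two nested Allocs should have been collapsed into one.
allocs = [node for node in f.maker.fgraph.toposort() if type(node.op).__name__ == "Alloc"]
print(len(allocs))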