@@ -1207,25 +1207,23 @@ def local_merge_alloc(fgraph, node):
     inputs_inner = node.inputs[0].owner.inputs
     dims_outer = inputs_outer[1:]
     dims_inner = inputs_inner[1:]
-    dims_outer_rev = dims_outer[::-1]
-    dims_inner_rev = dims_inner[::-1]
+    assert len(dims_inner) <= len(dims_outer)
     # Check if the broadcasting pattern is matched, in reversed ordering.
     # The reversed ordering is needed when an Alloc adds implicit new
     # broadcast dimensions to its inputs[0]. E.g.:
     # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
-    i = 0
-    for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False):
-        if dim_inner != dim_outer:
-            if isinstance(dim_inner, Constant) and dim_inner.data == 1:
-                pass
-            else:
-                dims_outer[-1 - i] = Assert(
-                    "You have a shape error in your graph. To see a better"
-                    " error message and a stack trace of where in your code"
-                    " the error is created, use the PyTensor flags"
-                    " optimizer=None or optimizer=fast_compile."
-                )(dim_outer, eq(dim_outer, dim_inner))
-        i += 1
+    for i, dim_inner in enumerate(reversed(dims_inner)):
+        dim_outer = dims_outer[-1 - i]
+        if dim_inner == dim_outer:
+            continue
+        if isinstance(dim_inner, Constant) and dim_inner.data == 1:
+            continue
+        dims_outer[-1 - i] = Assert(
+            "You have a shape error in your graph. To see a better"
+            " error message and a stack trace of where in your code"
+            " the error is created, use the PyTensor flags"
+            " optimizer=None or optimizer=fast_compile."
+        )(dim_outer, eq(dim_outer, dim_inner))
     return [alloc(inputs_inner[0], *dims_outer)]
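For context, here is a minimal sketch of the nested-Alloc pattern this rewrite targets, using the public PyTensor API. The variable names follow the example in the comment above and are illustrative, not taken from the patch:

import pytensor
import pytensor.tensor as pt

m = pt.scalar("m")
x, y, z, w = pt.iscalars("x", "y", "z", "w")

inner = pt.alloc(m, y, 1, 1)         # shape (y, 1, 1)
outer = pt.alloc(inner, x, y, z, w)  # broadcast up to shape (x, y, z, w)

# With the default rewrites enabled, the compiled graph should contain a
# single Alloc(m, x, y, z, w) rather than two nested Allocs.
f = pytensor.function([m, x, y, z, w], outer)
pytensor.dprint(f)

Because the inner dims align with the rightmost outer dims, the inner 1s are treated as broadcastable; when a non-1 inner dim differs from its outer counterpart, the rewrite wraps the outer dim in the Assert shown in the diff instead of failing outright.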