Skip to content

Commit 883084d

Browse files
authored
Arm backend: Fix op_index_tensor's ADD accumulator (#12056)
### Summary Add intermediate tensor for ADD accumulator in op_index_tensor.
1 parent 142b1c6 commit 883084d

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

backends/arm/operators/op_index_tensor.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,16 @@ def define_node(
189189
if i == 0:
190190
gather_index_name = reshaped_idxs.name
191191
else:
192+
add_idxs = tosa_graph.addIntermediate(
193+
reshaped_idxs.shape,
194+
reshaped_idxs.dtype,
195+
)
192196
tosa_graph.addOperator(
193197
ts.TosaOp.Op().ADD,
194198
[gather_index_name, reshaped_idxs.name],
195-
[gather_index_name],
199+
[add_idxs.name],
196200
)
201+
gather_index_name = add_idxs.name
197202

198203
gather_vals_shape = [N, K, C]
199204
reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype)
@@ -314,11 +319,16 @@ def define_node(
314319
if i == 0:
315320
gather_index_name = reshaped_idxs.name
316321
else:
322+
add_idxs = tosa_graph.addIntermediate(
323+
reshaped_idxs.shape,
324+
reshaped_idxs.dtype,
325+
)
317326
tosa_graph.addOperator(
318327
ts.TosaOp.Op().ADD,
319328
[gather_index_name, reshaped_idxs.name],
320-
[gather_index_name],
329+
[add_idxs.name],
321330
)
331+
gather_index_name = add_idxs.name
322332

323333
gather_vals_shape = [N, K, C]
324334
reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype)

0 commit comments

Comments
 (0)