Skip to content

Commit 57773ff

Browse files
committed
Arm backend: Use output.name in node visitors
As mentioned in #15381, TOSA tensors need unique naming, which gets tricky with submodules. It is handled in the TosaArg object, and therefore node visitors need to use output.name rather than node.name when creating new tensors. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I7a943deda0888c1de8796dd573e8befda3f074b2
1 parent bde6b11 commit 57773ff

File tree

8 files changed

+24
-27
lines changed

8 files changed

+24
-27
lines changed

backends/arm/operators/op_index_tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,14 @@ def define_node(
165165
# channels and thus the stride-shift.
166166
data = np.full(index_shape, int(values_strides[i] / C))
167167
mul_const = tosa_graph.addConst(index_shape, index_dtype, data)
168-
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift")
168+
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift")
169169
attr = ts.TosaSerializerAttribute()
170170
attr.MulAttribute()
171171
self._serialize_operator(
172172
node,
173173
tosa_graph,
174174
ts.Op.MUL,
175-
[index_name, mul_const.name, f"{node.name}_{i}_shift"],
175+
[index_name, mul_const.name, f"{output.name}_{i}_shift"],
176176
[stride_shifted_indices.name],
177177
attr,
178178
)
@@ -186,7 +186,7 @@ def define_node(
186186
stride_shifted_indices.name,
187187
gather_idx_shape,
188188
reshaped_idxs.name,
189-
shape_name_override=f"{node.name}_{i}_shape",
189+
shape_name_override=f"{output.name}_{i}_shape",
190190
)
191191

192192
# Guarantees that the accumulation tensor is properly
@@ -218,7 +218,7 @@ def define_node(
218218
values.name,
219219
gather_vals_shape,
220220
reshaped_input.name,
221-
shape_name_override=f"{node.name}_index_shape",
221+
shape_name_override=f"{output.name}_index_shape",
222222
)
223223

224224
gather_out_shape = (N, W, C)
@@ -244,5 +244,5 @@ def define_node(
244244
gather_out.name,
245245
list(output_shape),
246246
output.name,
247-
shape_name_override=f"{node.name}_output_shape",
247+
shape_name_override=f"{output.name}_output_shape",
248248
)

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def define_node(
4848
output.tosa_spec,
4949
)
5050

51-
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
51+
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift")
5252
attr = ts.TosaSerializerAttribute()
5353
attr.MulAttribute()
5454
self._serialize_operator(
5555
node,
5656
tosa_graph,
5757
ts.Op.MUL,
58-
[inputs[0].name, inputs[1].name, f"{node.name}_shift"],
58+
[inputs[0].name, inputs[1].name, f"{output.name}_shift"],
5959
[output.name],
6060
attr,
6161
)

backends/arm/operators/op_repeat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def define_node(
5656
(len(multiples),),
5757
ts.DType.SHAPE,
5858
list(tosa_shape(multiples, output.dim_order)),
59-
name=node.name + "_multiples",
59+
name=output.name + "_multiples",
6060
)
6161

6262
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_slice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def define_node(
120120
(starts_len,),
121121
ts.DType.SHAPE,
122122
starts,
123-
node.name + "_start_shape",
123+
output.name + "_start_shape",
124124
)
125125

126126
sizes = [size if i == dim else shape[i] for i in input_node.dim_order]
@@ -130,7 +130,7 @@ def define_node(
130130
sizes_len = 1
131131
sizes = [0]
132132
sizes_tensor = tosa_graph.addConst(
133-
(sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape"
133+
(sizes_len,), ts.DType.SHAPE, sizes, output.name + "_sizes_shape"
134134
)
135135

136136
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_tosa_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def define_node(
7272
else:
7373
input0_zp, input1_zp = 0, 0
7474

75-
input_A_ZP_name = f"{node.name}_A_ZP"
76-
input_B_ZP_name = f"{node.name}_B_ZP"
75+
input_A_ZP_name = f"{output.name}_A_ZP"
76+
input_B_ZP_name = f"{output.name}_B_ZP"
7777
tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name)
7878
tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name)
7979

backends/arm/operators/op_tosa_resize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def in_int16_range(x):
8484
scale_d_vals[1],
8585
]
8686
scales_tensor = tosa_graph.addConst(
87-
[len(scales)], ts.DType.SHAPE, scales, node.name + "_scales"
87+
[len(scales)], ts.DType.SHAPE, scales, output.name + "_scales"
8888
)
8989
offset = [int(v) for v in offset_yx.tolist()]
9090
offset_tensor = tosa_graph.addConst(
91-
[len(offset)], ts.DType.SHAPE, offset, node.name + "_offset"
91+
[len(offset)], ts.DType.SHAPE, offset, output.name + "_offset"
9292
)
9393
border = [int(v) for v in border_yx.tolist()]
9494
border_tensor = tosa_graph.addConst(
95-
[len(border)], ts.DType.SHAPE, border, node.name + "_border"
95+
[len(border)], ts.DType.SHAPE, border, output.name + "_border"
9696
)
9797
attr = ts.TosaSerializerAttribute()
9898
attr.ResizeAttribute(resize_mode)

backends/arm/operators/op_tosa_table.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,27 +44,24 @@ def define_node(
4444
if inputs[0].dtype == ts.DType.INT16:
4545
validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec)
4646

47-
if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr]
47+
# The name of the table constant is a bit complex.
48+
# The name of the pytorch buffer will be the target of last node argument.
49+
# However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus
50+
# needs to be taken from the last TosaArg.
51+
pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr]
52+
tosa_table_buffer_name = inputs[-1].name
53+
if pytorch_table_buffer_name not in self._exported_program.state_dict.keys():
4854
raise RuntimeError(
4955
f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
5056
)
5157

52-
table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr]
53-
54-
table_tensor_name = node.name + "_table"
55-
tosa_graph.addConst(
56-
table.shape,
57-
ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16,
58-
table.detach().numpy(),
59-
name=table_tensor_name,
60-
)
6158
attr = ts.TosaSerializerAttribute()
6259
attr.TableAttribute()
6360
self._serialize_operator(
6461
node,
6562
tosa_graph,
6663
ts.Op.TABLE,
67-
[inputs[0].name, table_tensor_name],
64+
[inputs[0].name, tosa_table_buffer_name],
6865
[output.name],
6966
attr,
7067
)

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def define_node(
6666
shape_len,
6767
ts.DType.SHAPE,
6868
shape_data,
69-
name=node.name + "_shape",
69+
name=output.name + "_shape",
7070
)
7171

7272
attr = ts.TosaSerializerAttribute()

0 commit comments

Comments
 (0)