Skip to content

Commit 12af535

Browse files
authored
Arm backend: Fix TOSA 1.0 node visitor for sum (#10908)
### Summary Fixes serialization for sum.dim_IntList node visitor as well as some rescale handling issues. ### Test plan Tested with internal and external GitHub CI. Signed-off-by: Per Åstrand <[email protected]>
1 parent 4b67dc9 commit 12af535

File tree

13 files changed

+37
-34
lines changed

13 files changed

+37
-34
lines changed

backends/arm/operators/op_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def define_node(
164164
scale_back = 1.0
165165
if inputs[0].dtype == ts.DType.INT8:
166166
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
167-
tosa_graph, inputs, node, self.tosa_specs
167+
tosa_graph, inputs, node, self.tosa_spec
168168
) # type: ignore[possibly-undefined]
169169
else:
170170
# input[0].dtype == ts.DType.INT32
@@ -192,7 +192,7 @@ def define_node(
192192
# Scale output back to 8 bit
193193
# pyre-ignore
194194
tqutils.insert_rescale_op_to_int8(
195-
tosa_graph, abs_output, scale_back, node, self.tosa_specs
195+
tosa_graph, abs_output, scale_back, node, self.tosa_spec
196196
) # type: ignore[possibly-undefined]
197197

198198

backends/arm/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def define_node(
174174
scale_back = 1.0
175175
if inputs[0].dtype == ts.DType.INT8:
176176
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
177-
tosa_graph, inputs, node, self.tosa_specs
177+
tosa_graph, inputs, node, self.tosa_spec
178178
)
179179
else:
180180
# input[0].dtype == ts.DType.INT32
@@ -202,7 +202,7 @@ def define_node(
202202
# Scale output back to 8 bit
203203
# pyre-ignore
204204
tqutils.insert_rescale_op_to_int8(
205-
tosa_graph, add_output, scale_back, node, self.tosa_specs
205+
tosa_graph, add_output, scale_back, node, self.tosa_spec
206206
) # type: ignore[possibly-undefined]
207207

208208

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def define_node(
9898
if inputs[0].dtype == ts.DType.INT8:
9999
# Rescale inputs to 32 bit
100100
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
101-
tosa_graph, inputs, node, self.tosa_specs
101+
tosa_graph, inputs, node, self.tosa_spec
102102
)
103103

104104
# Update IO

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_le.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_maximum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def define_node(
129129
)
130130

131131
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
132-
tosa_graph, inputs, node, self.tosa_specs
132+
tosa_graph, inputs, node, self.tosa_spec
133133
)
134134

135135
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -155,5 +155,5 @@ def define_node(
155155
if output.dtype == ts.DType.INT8:
156156
# insert RESCALE from int32 back to int8
157157
tqutils.insert_rescale_op_to_int8(
158-
tosa_graph, max_output, scale_back, node, self.tosa_specs
158+
tosa_graph, max_output, scale_back, node, self.tosa_spec
159159
)

backends/arm/operators/op_minimum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def define_node(
128128
)
129129

130130
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
131-
tosa_graph, inputs, node, self.tosa_specs
131+
tosa_graph, inputs, node, self.tosa_spec
132132
)
133133

134134
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -154,5 +154,5 @@ def define_node(
154154
if output.dtype == ts.DType.INT8:
155155
# insert RESCALE from int32 back to int8
156156
tqutils.insert_rescale_op_to_int8(
157-
tosa_graph, min_output, scale_back, node, self.tosa_specs
157+
tosa_graph, min_output, scale_back, node, self.tosa_spec
158158
)

backends/arm/operators/op_mul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ def define_node(
189189
input_A,
190190
input_A_qargs.zp,
191191
[1.0],
192-
tosa_spec=self.tosa_specs,
192+
tosa_spec=self.tosa_spec,
193193
)
194194
input_B_rescaled = tqutils.build_rescale_to_int32(
195195
tosa_graph,
196196
input_B,
197197
input_B_qargs.zp,
198198
[1.0],
199-
tosa_spec=self.tosa_specs,
199+
tosa_spec=self.tosa_spec,
200200
)
201201

202202
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
@@ -211,7 +211,7 @@ def define_node(
211211
)
212212
output_scale = input_A_qargs.scale * input_B_qargs.scale
213213
tqutils.insert_rescale_op_to_int8(
214-
tosa_graph, mul_output, output_scale, node, self.tosa_specs
214+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
215215
)
216216

217217

0 commit comments

Comments
 (0)