Commit efb2235

Arm Backend: Add support for Conv2d for TOSA 1.0
Signed-off-by: Saoirse Stewart <[email protected]>
Change-Id: I4f4122d297416c9a65f4011af8b357532cfeb26a
1 parent a765d7e commit efb2235

2 files changed: 242 additions & 17 deletions

backends/arm/operator_support/convolution_support.py

Lines changed: 8 additions & 1 deletion

@@ -11,7 +11,11 @@
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.arm.tosa_specification import (
+    Tosa_0_80,
+    Tosa_1_00,
+    TosaSpecification,
+)
 from executorch.exir.dialects._ops import ops as exir_ops


@@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

         # Hardware specific constraints
         if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            # TODO remove this once TOSA 1.0 support for u55 is added.
+            if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
+                return False
             return True
         else:
             return self._is_node_supported_u55(node)

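Note: the guard added above can be summarized with a small, self-contained sketch. The FakeSpec class and conv2d_supported function below are hypothetical stand-ins written for illustration only; they mirror the control flow of is_node_tosa_supported but are not part of the ExecuTorch API.

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeSpec:
    """Stand-in for TosaSpecification; only the fields the guard reads."""

    version: str
    is_U55_subset: bool = False
    extensions: List[str] = field(default_factory=list)


def conv2d_supported(spec: FakeSpec) -> bool:
    # Mirrors the new control flow: non-U55 specs are accepted, except that a
    # TOSA 1.0 spec carrying the "u55" extension is rejected until TOSA 1.0
    # support for u55 is added.
    if not (spec.version == "0.80" and spec.is_U55_subset):
        if spec.version == "1.0" and "u55" in spec.extensions:
            return False
        return True
    # The real check delegates to U55-specific constraints here.
    return True


print(conv2d_supported(FakeSpec("1.0")))                      # True
print(conv2d_supported(FakeSpec("1.0", extensions=["u55"])))  # False
print(conv2d_supported(FakeSpec("0.80")))                     # True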
backends/arm/operators/op_conv2d.py

Lines changed: 234 additions & 16 deletions
@@ -4,12 +4,11 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
-from typing import List
+from typing import Any, List

+import numpy as np
 import torch

-import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
-
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
     get_output_qparams,
@@ -19,14 +18,20 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
+from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape


 @register_node_visitor
-class Conv2dVisitor(NodeVisitor):
+class Conv2dVisitor_0_80(NodeVisitor):
     target = "aten.convolution.default"

+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
     def __init__(self, *args):
         super().__init__(*args)

@@ -54,10 +59,13 @@ def adjust_pad_if_needed(
     def define_node(
         self,
         node: torch.fx.Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+
         input, weight, bias, stride, pad, dilation, _, _, group = inputs

         # Get the attributes of convolution.
@@ -170,14 +178,224 @@ def define_node(
             input_scale = input_qparams[0].scale  # type: ignore[possibly-undefined]  # pyre-ignore [61]
             weight_scale = input_qparams[1].scale  # pyre-ignore [61]
             output_qargs = get_output_qparams(node)
-            build_rescale_conv_output(
-                tosa_graph,
-                # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
-                conv2d_res,  # type: ignore[possibly-undefined]
-                output.name,
-                output.dtype,
-                [input_scale],
-                [weight_scale],
-                [output_qargs[0].scale],
-                output_qargs[0].zp,
+            post_conv2d_scale = [
+                (inp * w) / out
+                for inp, w, out in zip(
+                    [input_scale], [weight_scale], [output_qargs[0].scale]
+                )
+            ]
+
+            build_rescale_v0_80(
+                tosa_fb=tosa_graph,
+                scale=post_conv2d_scale,
+                input_node=conv2d_res,  # type: ignore[possibly-undefined]
+                output_name=output.name,
+                output_type=output.dtype,
+                input_zp=0,
+                output_zp=output_qargs[0].zp,
+                per_channel=isinstance(weight_scale, torch.Tensor),
+            )  # type: ignore[call-arg]
+
+
+@register_node_visitor
+class Conv2dVisitor(NodeVisitor):
+    target = "aten.convolution.default"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    # torch.nn.Conv2d does not require the result of
+    # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
+    # to be an integer, but tosa currently strictly require this property.
+    # This function adjusts the pad value to meet the requirement.
+    def adjust_pad_if_needed(
+        self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int
+    ) -> int:
+        mod_remainder = (
+            input_size + 2 * pad - dilation * (input_weight - 1) - 1
+        ) % stride
+
+        # No need to adjust
+        if mod_remainder == 0:
+            return pad
+
+        if mod_remainder > pad:
+            raise RuntimeError(
+                "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
+            )
+        return pad - mod_remainder
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        import serializer.tosa_serializer as ts  # type: ignore
+        from tosa.RoundingMode import RoundingMode  # type: ignore
+
+        input, weight, bias, stride, pad, dilation, _, _, group = inputs
+
+        # Get the attributes of convolution.
+        attr = ts.TosaSerializerAttribute()
+        pad_attr = [val for val in pad.special for _ in (0, 1)]
+        stride_attr = stride.special
+        dilation_attr = dilation.special
+
+        # Adjust the pad value if needed to meet the
+        # strict convolution output shape calculation.
+        pad_attr[1] = self.adjust_pad_if_needed(
+            input.shape[2],
+            weight.shape[2],
+            stride_attr[0],
+            pad_attr[1],
+            dilation_attr[0],
+        )
+        pad_attr[3] = self.adjust_pad_if_needed(
+            input.shape[3],
+            weight.shape[3],
+            stride_attr[1],
+            pad_attr[3],
+            dilation_attr[1],
+        )
+
+        input_zp = 0
+        if inputs[0].dtype == ts.DType.INT8:
+            # int8 input requires quantization information
+            input_qparams = get_input_qparams(node)
+            input_zp = input_qparams[0].zp
+
+        tosa_graph.addConst([1], output.dtype, [input_zp], name=f"{node.name}_input_zp")
+        tosa_graph.addConst([1], output.dtype, [0], name=f"{node.name}_weight_zp")
+        acc_type = (
+            inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32
+        )
+
+        # Non-bias case.
+        if len(node.all_input_nodes) == 2:
+            # Create a zero bias tensor if not presented
+            out_channels = weight.shape[0]
+            bias_name = "bias" + node.name.split("default", 1)[1]
+            bias_type = output.dtype
+            if output.dtype == ts.DType.INT8:
+                # Conv is quantized to int8, but the TOSA operator has
+                # output type int32, and the bias must be the same type
+                # as the TOSA output type
+                bias_type = ts.DType.INT32
+            bias = tosa_graph.addConst(
+                [out_channels],
+                bias_type,
+                [0] * out_channels,
+                name=bias_name,
+            )
+
+        # The output type is int32 when input type is int8.
+        conv2d_output_name = output.name
+        if output.dtype == ts.DType.INT8:
+            conv2d_res = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
+            )
+            conv2d_output_name = conv2d_res.name
+
+        # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
+        in_channels = input.shape[1]
+        out_channels = weight.shape[0]
+        if (in_channels == group.number) and (out_channels % in_channels) == 0:
+            """Depthwise convolution case"""
+            # Reshape torch shape format of weight tensor to tosa required format.
+            # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
+            m_length = int(out_channels / in_channels)
+            weight_post_shape = [
+                weight.shape[2],
+                weight.shape[3],
+                in_channels,
+                m_length,
+            ]
+
+            weight_reshaped = tosa_graph.addIntermediate(
+                weight_post_shape,
+                weight.dtype,
+            )
+            shape = tosa_graph.addConst(
+                np.array(weight_post_shape).shape,
+                ts.DType.SHAPE,
+                np.array(weight_post_shape),
+                name=weight_reshaped.name + "_shape",
+            )
+
+            attr = ts.TosaSerializerAttribute()
+            attr.ReshapeAttribute()
+            tosa_graph.addOperator(
+                ts.TosaOp.Op().RESHAPE,
+                [weight.name, shape.name],
+                [weight_reshaped.name],
+                attr,
+            )
+
+            tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D
+            weight_name = weight_reshaped.name
+
+            attr.DepthwiseConv2dAttribute(
+                pad=pad_attr,
+                stride=stride_attr,
+                dilation=dilation_attr,
+                local_bound=False,
+                acc_type=acc_type,
+            )
+        else:
+            """Regular convolution case"""
+            tosa_op = ts.TosaOp.Op().CONV2D
+            weight_name = weight.name
+
+            attr.Conv2dAttribute(
+                pad=pad_attr,
+                stride=stride_attr,
+                dilation=dilation_attr,
+                local_bound=False,
+                acc_type=acc_type,
+            )
+
+        tosa_graph.addOperator(
+            tosa_op,
+            [
+                input.name,
+                weight_name,
+                bias.name,
+                f"{node.name}_input_zp",
+                f"{node.name}_weight_zp",
+            ],
+            [conv2d_output_name],
+            attr,
+        )
+
+        # For quantized convolution, rescale the output value back to the same
+        # integer value domain of the next op. Otherwise return float32 output.
+        if inputs[0].dtype == ts.DType.INT8:
+            # Get scale_factor from input, weight, and output.
+            input_scale = input_qparams[0].scale  # type: ignore[possibly-undefined]  # pyre-ignore [61]
+            weight_scale = input_qparams[1].scale  # pyre-ignore [61]
+            output_qargs = get_output_qparams(node)
+            post_conv2d_scale = [
+                (inp * w) / out
+                for inp, w, out in zip(
+                    [input_scale], [weight_scale], [output_qargs[0].scale]
+                )
+            ]
+            build_rescale(
+                tosa_fb=tosa_graph,
+                scale=post_conv2d_scale,
+                input_node=conv2d_res,  # type: ignore[possibly-undefined]
+                output_name=output.name,
+                output_type=output.dtype,
+                input_zp=0,
+                output_zp=output_qargs[0].zp,
+                per_channel=isinstance(weight_scale, torch.Tensor),
+                rounding_mode=RoundingMode.SINGLE_ROUND,
             )

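Note: to make the numeric behaviour of the new TOSA 1.0 visitor easier to follow, here is a standalone sketch of the padding adjustment, the depthwise weight reshape, and the post-convolution rescale factor. The example shapes and scales below are invented for illustration; only the arithmetic mirrors the code above.

# Hypothetical, self-contained illustration; example values are made up.

def adjust_pad_if_needed(input_size: int, kernel_size: int, stride: int,
                         pad: int, dilation: int) -> int:
    # Same arithmetic as Conv2dVisitor.adjust_pad_if_needed: shrink the pad
    # until the TOSA output-size expression divides evenly by the stride.
    mod_remainder = (input_size + 2 * pad - dilation * (kernel_size - 1) - 1) % stride
    if mod_remainder == 0:
        return pad
    if mod_remainder > pad:
        raise RuntimeError("expected to be fixed earlier by the SizeAdjustConv2d pass")
    return pad - mod_remainder


# 7x7 input, 3x3 kernel, stride 2, pad 1: (7 + 2 - 2 - 1) % 2 == 0 -> pad stays 1.
print(adjust_pad_if_needed(7, 3, 2, 1, 1))  # 1
# 8x8 input, same kernel: remainder 1 -> pad shrinks to 0.
print(adjust_pad_if_needed(8, 3, 2, 1, 1))  # 0

# Depthwise case: a torch weight of shape (Co, Ci/G, H, W) with
# in_channels == groups is laid out as (H, W, C, M) for TOSA,
# where M = Co / Ci is the channel multiplier.
co, ci, kh, kw = 64, 32, 3, 3           # e.g. groups == in_channels == 32
weight_post_shape = [kh, kw, ci, co // ci]
print(weight_post_shape)                # [3, 3, 32, 2]

# Rescale applied after an int8 convolution: the int32 accumulator is scaled
# by (input_scale * weight_scale) / output_scale so the result lands back in
# the output tensor's int8 quantization domain.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.1
post_conv2d_scale = (input_scale * weight_scale) / output_scale
print(post_conv2d_scale)                # ~0.001 (up to float rounding)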