Skip to content

Commit 7dc059f

Browse files
Arm backend: Add docstrings for bmm and conv2d operators (#14461)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent dfa17bc commit 7dc059f

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

backends/arm/operators/op_bmm.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-unsafe
8+
"""Provide a visitor for lowering batched matmul (BMM) to TOSA."""
9+
810
from typing import Any, List
911

1012
import torch
@@ -30,6 +32,13 @@
3032

3133
@register_node_visitor
3234
class BMMVisitor(NodeVisitor):
35+
"""Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``.
36+
37+
INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND
38+
rounding and output zero-point.
39+
40+
"""
41+
3342
target = "aten.bmm.default"
3443

3544
tosa_specs = [
@@ -47,7 +56,7 @@ def define_node(
4756
inputs: List[TosaArg],
4857
output: TosaArg,
4958
) -> None:
50-
59+
"""Define the TOSA ``MATMUL`` operator and optional rescale."""
5160
import serializer.tosa_serializer as ts # type: ignore
5261

5362
validate_num_inputs(self.target, inputs, 2)

backends/arm/operators/op_conv2d.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
8+
79
import itertools
810
from typing import Any, List
911

@@ -28,6 +30,12 @@
2830

2931
@register_node_visitor
3032
class Conv2dVisitor(NodeVisitor):
33+
"""Provide a visitor that lowers ``aten.convolution`` to TOSA.
34+
35+
Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate.
36+
37+
"""
38+
3139
target = "aten.convolution.default"
3240

3341
tosa_specs = [
@@ -38,13 +46,32 @@ class Conv2dVisitor(NodeVisitor):
3846
def __init__(self, *args):
3947
super().__init__(*args)
4048

41-
# torch.nn.Conv2d does not require the result of
42-
# `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
43-
# to be an integer, but tosa currently strictly require this property.
44-
# This function adjusts the pad value to meet the requirement.
4549
def adjust_pad_if_needed(
4650
self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int
4751
) -> int:
52+
"""Adjust padding to satisfy TOSA's integer output-size requirement.
53+
54+
Torch ``Conv2d`` does not require the result of
55+
``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
56+
integer, but TOSA does. This helper reduces the provided padding so
57+
that the expression becomes divisible by ``stride``.
58+
59+
Args:
60+
input_size (int): Spatial input size along the dimension (H or W).
61+
input_weight (int): Kernel size along the same dimension.
62+
stride (int): Stride along the same dimension.
63+
pad (int): Padding value to adjust (bottom or right after duplication).
64+
dilation (int): Dilation along the same dimension.
65+
66+
Returns:
67+
int: Adjusted padding value that yields an integer output size.
68+
69+
Raises:
70+
RuntimeError: If the required adjustment exceeds the provided
71+
padding, which should be handled by the ``SizeAdjustInputPass``
72+
pass instead.
73+
74+
"""
4875
mod_remainder = (
4976
input_size + 2 * pad - dilation * (input_weight - 1) - 1
5077
) % stride
@@ -55,7 +82,8 @@ def adjust_pad_if_needed(
5582

5683
if mod_remainder > pad:
5784
raise RuntimeError(
58-
"This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
85+
"This case should be handled by the SizeAdjustInputPass pass, "
86+
"is it enabled?"
5987
)
6088
return pad - mod_remainder
6189

@@ -66,7 +94,7 @@ def define_node(
6694
inputs: List[TosaArg],
6795
output: TosaArg,
6896
) -> None:
69-
97+
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
7098
import serializer.tosa_serializer as ts # type: ignore
7199
from tosa.RoundingMode import RoundingMode # type: ignore
72100

@@ -133,7 +161,7 @@ def define_node(
133161
in_channels = input.shape[1]
134162
out_channels = weight.shape[0]
135163
if (in_channels == group.number) and (out_channels % in_channels) == 0:
136-
"""Depthwise convolution case"""
164+
"""Depthwise convolution case."""
137165
# Reshape torch shape format of weight tensor to tosa required format.
138166
# https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
139167
m_length = int(out_channels / in_channels)
@@ -178,7 +206,7 @@ def define_node(
178206
acc_type=acc_type,
179207
)
180208
else:
181-
"""Regular convolution case"""
209+
"""Regular convolution case."""
182210
tosa_op = ts.TosaOp.Op().CONV2D
183211
weight_name = weight.name
184212

0 commit comments

Comments
 (0)