4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
+ """Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
8
+
7
9
import itertools
8
10
from typing import Any , List
9
11
28
30
29
31
@register_node_visitor
30
32
class Conv2dVisitor (NodeVisitor ):
33
+ """Provide a visitor that lowers ``aten.convolution`` to TOSA.
34
+
35
+ Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate.
36
+
37
+ """
38
+
31
39
target = "aten.convolution.default"
32
40
33
41
tosa_specs = [
@@ -38,13 +46,32 @@ class Conv2dVisitor(NodeVisitor):
38
46
def __init__ (self , * args ):
39
47
super ().__init__ (* args )
40
48
41
- # torch.nn.Conv2d does not require the result of
42
- # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
43
- # to be an integer, but tosa currently strictly require this property.
44
- # This function adjusts the pad value to meet the requirement.
45
49
def adjust_pad_if_needed (
46
50
self , input_size : int , input_weight : int , stride : int , pad : int , dilation : int
47
51
) -> int :
52
+ """Adjust padding to satisfy TOSA's integer output-size requirement.
53
+
54
+ Torch ``Conv2d`` does not require the result of
55
+ ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
56
+ integer, but TOSA does. This helper reduces the provided padding so
57
+ that the expression becomes divisible by ``stride``.
58
+
59
+ Args:
60
+ input_size (int): Spatial input size along the dimension (H or W).
61
+ input_weight (int): Kernel size along the same dimension.
62
+ stride (int): Stride along the same dimension.
63
+ pad (int): Padding value to adjust (bottom or right after duplication).
64
+ dilation (int): Dilation along the same dimension.
65
+
66
+ Returns:
67
+ int: Adjusted padding value that yields an integer output size.
68
+
69
+ Raises:
70
+ RuntimeError: If the required adjustment exceeds the provided
71
+ padding, which should be handled by the ``SizeAdjustInputPass``
72
+ pass instead.
73
+
74
+ """
48
75
mod_remainder = (
49
76
input_size + 2 * pad - dilation * (input_weight - 1 ) - 1
50
77
) % stride
@@ -55,7 +82,8 @@ def adjust_pad_if_needed(
55
82
56
83
if mod_remainder > pad :
57
84
raise RuntimeError (
58
- "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
85
+ "This case should be handled by the SizeAdjustInputPass pass, "
86
+ "is it enabled?"
59
87
)
60
88
return pad - mod_remainder
61
89
@@ -66,7 +94,7 @@ def define_node(
66
94
inputs : List [TosaArg ],
67
95
output : TosaArg ,
68
96
) -> None :
69
-
97
+ """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
70
98
import serializer .tosa_serializer as ts # type: ignore
71
99
from tosa .RoundingMode import RoundingMode # type: ignore
72
100
@@ -133,7 +161,7 @@ def define_node(
133
161
in_channels = input .shape [1 ]
134
162
out_channels = weight .shape [0 ]
135
163
if (in_channels == group .number ) and (out_channels % in_channels ) == 0 :
136
- """Depthwise convolution case"""
164
+ """Depthwise convolution case. """
137
165
# Reshape torch shape format of weight tensor to tosa required format.
138
166
# https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
139
167
m_length = int (out_channels / in_channels )
@@ -178,7 +206,7 @@ def define_node(
178
206
acc_type = acc_type ,
179
207
)
180
208
else :
181
- """Regular convolution case"""
209
+ """Regular convolution case. """
182
210
tosa_op = ts .TosaOp .Op ().CONV2D
183
211
weight_name = weight .name
184
212
0 commit comments