44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7+ """Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
8+
79import itertools
810from typing import Any , List
911
2830
2931@register_node_visitor
3032class Conv2dVisitor (NodeVisitor ):
33+ """Provide a visitor that lowers ``aten.convolution`` to TOSA.
34+
35+ Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate.
36+
37+ """
38+
3139 target = "aten.convolution.default"
3240
3341 tosa_specs = [
@@ -38,13 +46,32 @@ class Conv2dVisitor(NodeVisitor):
3846 def __init__ (self , * args ):
3947 super ().__init__ (* args )
4048
41- # torch.nn.Conv2d does not require the result of
42- # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
43- # to be an integer, but tosa currently strictly require this property.
44- # This function adjusts the pad value to meet the requirement.
4549 def adjust_pad_if_needed (
4650 self , input_size : int , input_weight : int , stride : int , pad : int , dilation : int
4751 ) -> int :
52+ """Adjust padding to satisfy TOSA's integer output-size requirement.
53+
54+ Torch ``Conv2d`` does not require the result of
55+ ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
56+ integer, but TOSA does. This helper reduces the provided padding so
57+ that the expression becomes divisible by ``stride``.
58+
59+ Args:
60+ input_size (int): Spatial input size along the dimension (H or W).
61+ input_weight (int): Kernel size along the same dimension.
62+ stride (int): Stride along the same dimension.
63+ pad (int): Padding value to adjust (bottom or right after duplication).
64+ dilation (int): Dilation along the same dimension.
65+
66+ Returns:
67+ int: Adjusted padding value that yields an integer output size.
68+
69+ Raises:
70+ RuntimeError: If the required adjustment exceeds the provided
71+ padding, which should be handled by the ``SizeAdjustInputPass``
72+ pass instead.
73+
74+ """
4875 mod_remainder = (
4976 input_size + 2 * pad - dilation * (input_weight - 1 ) - 1
5077 ) % stride
@@ -55,7 +82,8 @@ def adjust_pad_if_needed(
5582
5683 if mod_remainder > pad :
5784 raise RuntimeError (
58- "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
85+ "This case should be handled by the SizeAdjustInputPass pass, "
86+ "is it enabled?"
5987 )
6088 return pad - mod_remainder
6189
@@ -66,7 +94,7 @@ def define_node(
6694 inputs : List [TosaArg ],
6795 output : TosaArg ,
6896 ) -> None :
69-
97+ """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
7098 import serializer .tosa_serializer as ts # type: ignore
7199 from tosa .RoundingMode import RoundingMode # type: ignore
72100
@@ -133,7 +161,7 @@ def define_node(
133161 in_channels = input .shape [1 ]
134162 out_channels = weight .shape [0 ]
135163 if (in_channels == group .number ) and (out_channels % in_channels ) == 0 :
136- """Depthwise convolution case"""
164+ """Depthwise convolution case. """
137165 # Reshape torch shape format of weight tensor to tosa required format.
138166 # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
139167 m_length = int (out_channels / in_channels )
@@ -178,7 +206,7 @@ def define_node(
178206 acc_type = acc_type ,
179207 )
180208 else :
181- """Regular convolution case"""
209+ """Regular convolution case. """
182210 tosa_op = ts .TosaOp .Op ().CONV2D
183211 weight_name = weight .name
184212
0 commit comments