Skip to content

Commit c65e915

Browse files
committed
add doc strings to pytorch-specific padding calculation functions
1 parent e55b29c commit c65e915

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

hls4ml/converters/utils.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def compute_padding_1d(pad_type, in_size, stride, filt_size):
4545
is odd, it will add the extra column to the right.
4646
4747
Args:
48-
pad_type (str): Padding type, one of ``same``, `valid`` or ``causal`` (case insensitive).
48+
pad_type (str): Padding type, one of ``same``, ``valid`` or ``causal`` (case insensitive).
4949
in_size (int): Input size.
5050
stride (int): Stride length.
5151
filt_size (int): Length of the kernel window.
@@ -135,6 +135,23 @@ def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_widt
135135

136136

137137
def compute_padding_1d_pytorch(pad_type, in_size, stride, filt_size, dilation):
138+
"""Computes the amount of padding required on each side of the 1D input tensor following pytorch conventions.
139+
140+
In case of ``same`` padding, this routine tries to pad evenly left and right, but if the amount of columns to be added
141+
is odd, it will add the extra column to the right.
142+
143+
Args:
144+
pad_type (str or int): Padding type. If string, one of ``same``, ``valid`` or ``causal`` (case insensitive).
145+
in_size (int): Input size.
146+
stride (int): Stride length.
147+
filt_size (int): Length of the kernel window.
148+
149+
Raises:
150+
Exception: Raised if the padding type is unknown.
151+
152+
Returns:
153+
tuple: Tuple containing the padded input size, left and right padding values.
154+
"""
138155
if isinstance(pad_type, str):
139156
if pad_type.lower() == 'same':
140157
n_out = int(
@@ -176,6 +193,26 @@ def compute_padding_1d_pytorch(pad_type, in_size, stride, filt_size, dilation):
176193
def compute_padding_2d_pytorch(
177194
pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width, dilation_height, dilation_width
178195
):
196+
"""Computes the amount of padding required on each side of the 2D input tensor following pytorch conventions.
197+
198+
In case of ``same`` padding, this routine tries to pad evenly left and right (top and bottom), but if the amount of
199+
columns to be added is odd, it will add the extra column to the right/bottom.
200+
201+
Args:
202+
pad_type (str or int): Padding type. If string, one of ``same`` or ``valid`` (case insensitive).
203+
in_height (int): The height of the input tensor.
204+
in_width (int): The width of the input tensor.
205+
stride_height (int): Stride height.
206+
stride_width (int): Stride width.
207+
filt_height (int): Height of the kernel window.
208+
filt_width (int): Width of the kernel window.
209+
210+
Raises:
211+
Exception: Raised if the padding type is unknown.
212+
213+
Returns:
214+
tuple: Tuple containing the padded input height, width, and top, bottom, left and right padding values.
215+
"""
179216
if isinstance(pad_type, str):
180217
if pad_type.lower() == 'same':
181218
# Height

0 commit comments

Comments
 (0)