@@ -45,7 +45,7 @@ def compute_padding_1d(pad_type, in_size, stride, filt_size):
4545 is odd, it will add the extra column to the right.
4646
4747 Args:
48- pad_type (str): Padding type, one of ``same``, `valid`` or ``causal`` (case insensitive).
48+ pad_type (str): Padding type, one of ``same``, `` valid`` or ``causal`` (case insensitive).
4949 in_size (int): Input size.
5050 stride (int): Stride length.
5151 filt_size (int): Length of the kernel window.
@@ -135,6 +135,23 @@ def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_widt
135135
136136
137137def compute_padding_1d_pytorch (pad_type , in_size , stride , filt_size , dilation ):
138+ """Computes the amount of padding required on each side of the 1D input tensor following pytorch conventions.
139+
140+ In case of ``same`` padding, this routine tries to pad evenly left and right, but if the amount of columns to be added
141+ is odd, it will add the extra column to the right.
142+
143+ Args:
144+ pad_type (str or int): Padding type. If string, one of ``same``, ``valid`` or ``causal`` (case insensitive).
145+ in_size (int): Input size.
146+ stride (int): Stride length.
147+ filt_size (int): Length of the kernel window.
148+
149+ Raises:
150+ Exception: Raised if the padding type is unknown.
151+
152+ Returns:
153+ tuple: Tuple containing the padded input size, left and right padding values.
154+ """
138155 if isinstance (pad_type , str ):
139156 if pad_type .lower () == 'same' :
140157 n_out = int (
@@ -176,6 +193,26 @@ def compute_padding_1d_pytorch(pad_type, in_size, stride, filt_size, dilation):
176193def compute_padding_2d_pytorch (
177194 pad_type , in_height , in_width , stride_height , stride_width , filt_height , filt_width , dilation_height , dilation_width
178195):
196+ """Computes the amount of padding required on each side of the 2D input tensor following pytorch conventions.
197+
198+ In case of ``same`` padding, this routine tries to pad evenly left and right (top and bottom), but if the amount of
199+ columns to be added is odd, it will add the extra column to the right/bottom.
200+
201+ Args:
202+ pad_type (str or int): Padding type. If string, one of ``same`` or ``valid`` (case insensitive).
203+ in_height (int): The height of the input tensor.
204+ in_width (int): The width of the input tensor.
205+ stride_height (int): Stride height.
206+ stride_width (int): Stride width.
207+ filt_height (int): Height of the kernel window.
208+ filt_width (int): Width of the kernel window.
209+
210+ Raises:
211+ Exception: Raised if the padding type is unknown.
212+
213+ Returns:
214+ tuple: Tuple containing the padded input height, width, and top, bottom, left and right padding values.
215+ """
179216 if isinstance (pad_type , str ):
180217 if pad_type .lower () == 'same' :
181218 # Height
0 commit comments