Skip to content

Commit d58c1c5

Browse files
committed
Update Conv2d layer
1 parent ae8fc36 commit d58c1c5

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

cvt_tensorflow/models/layers/utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __init__(
182182
out_channels: int,
183183
kernel_size: Union[int, tuple, list],
184184
stride: Union[int, tuple, list] = 1,
185-
padding: int = 0,
185+
padding: Union[int, tuple, list] = 0,
186186
dilation: Union[int, tuple, list] = 1,
187187
groups: int = 1,
188188
bias: bool = True,
@@ -203,16 +203,22 @@ def __init__(
203203
Number of channels produced by the convolution.
204204
kernel_size : Union[int, tuple, list]
205205
Size of the convolving kernel.
206+
If tuple/list of two ints, the first int is used for the height
207+
dimension, and the second int for the width dimension.
206208
stride : Union[int, tuple, list], optional
207209
Stride of the convolution.
210+
If tuple/list of two ints, the first int is used for the height
211+
dimension, and the second int for the width dimension.
208212
The default is 1.
209-
padding : int, optional
213+
padding : Union[int, tuple, list], optional
210214
Padding added to all four sides of the input.
215+
If tuple/list of two ints, the first int is used for the height
216+
dimension, and the second int for the width dimension.
211217
The default is 0.
212218
dilation : Union[int, tuple, list], optional
213-
An integer or tuple/list of 2 integers, specifying the dilation
214-
rate to use for dilated convolution. Can be a single integer to
215-
specify the same value for all spatial dimensions.
219+
Spacing between kernel elements.
220+
If tuple/list of two ints, the first int is used for the height
221+
dimension, and the second int for the width dimension.
216222
The default is 1.
217223
groups : int, optional
218224
A positive integer specifying the number of groups in which
@@ -264,10 +270,9 @@ def __init__(
264270
self.data_format = "channels_last"
265271

266272
# Pad Layer
267-
if self.padding > 0:
268-
self.pad_layer = tf.keras.layers.ZeroPadding2D(
269-
padding=padding, data_format=self.data_format
270-
)
273+
self.pad_layer = tf.keras.layers.ZeroPadding2D(
274+
padding=self.padding, data_format=self.data_format
275+
)
271276

272277
self.conv_layer = tf.keras.layers.Conv2D(
273278
filters=self.out_channels,
@@ -295,10 +300,7 @@ def uniform_initializer_spec(self):
295300
def call(self, inputs, *args, **kwargs):
296301
if self.data_format == "channels_last":
297302
inputs = _to_channel_last(inputs)
298-
if self.padding > 0:
299-
x = self.pad_layer(inputs)
300-
else:
301-
x = inputs
303+
x = self.pad_layer(inputs)
302304
x = self.conv_layer(x)
303305
if self.data_format == "channels_last":
304306
x = _to_channel_first(x)

0 commit comments

Comments
 (0)