Skip to content

Commit 3fe2def

Browse files
authored
Merge pull request #14540 from JiabinYang/fix_pool2d_doc
Fix pool2d doc and add pool2d test in test_layers
2 parents 61c5f13 + 510601b commit 3fe2def

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,11 +2134,16 @@ def pool2d(input,
21342134
input tensor is NCHW, where N is batch size, C is
21352135
the number of channels, H is the height of the
21362136
feature, and W is the width of the feature.
2137-
pool_size (int): The side length of pooling windows. All pooling
2138-
windows are squares with pool_size on a side.
2137+
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
2138+
it must contain two integers, (pool_size_Height, pool_size_Width).
2139+
Otherwise, the pool kernel size will be a square of an int.
21392140
pool_type: ${pooling_type_comment}
2140-
pool_stride (int): stride of the pooling layer.
2141-
pool_padding (int): padding size.
2141+
pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
2142+
it must contain two integers, (pool_stride_Height, pool_stride_Width).
2143+
Otherwise, the pool stride size will be a square of an int.
2144+
pool_padding (int|list|tuple): The pool padding size. If pool padding size is a tuple,
2145+
it must contain two integers, (pool_padding_on_Height, pool_padding_on_Width).
2146+
Otherwise, the pool padding size will be a square of an int.
21422147
global_pooling (bool): ${global_pooling_comment}
21432148
use_cudnn (bool): ${use_cudnn_comment}
21442149
ceil_mode (bool): ${ceil_mode_comment}

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,17 @@ def test_sequence_unpad(self):
202202
self.assertIsNotNone(layers.sequence_unpad(x=x, length=length))
203203
print(str(program))
204204

205+
def test_pool2d(self):
206+
program = Program()
207+
with program_guard(program):
208+
x = layers.data(name='x', shape=[3, 224, 224], dtype='float32')
209+
self.assertIsNotNone(
210+
layers.pool2d(
211+
x,
212+
pool_size=[5, 3],
213+
pool_stride=[1, 2],
214+
pool_padding=(2, 1)))
215+
205216
def test_lstm_unit(self):
206217
program = Program()
207218
with program_guard(program):

0 commit comments

Comments
 (0)