Skip to content

Commit 087d8e7

Browse files
author
chengduo
authored
Merge pull request #8551 from chengduoZH/fixbug/conv2d_python
Fix the bug of conv2d
2 parents fe7c181 + 5d30142 commit 087d8e7

File tree

2 files changed

+95
-50
lines changed

2 files changed

+95
-50
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 36 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..param_attr import ParamAttr
2222
from layer_function_generator import autodoc
2323
from tensor import concat
24+
import utils
2425

2526
__all__ = [
2627
'fc',
@@ -1138,8 +1139,8 @@ def sequence_conv(input,
11381139
def conv2d(input,
11391140
num_filters,
11401141
filter_size,
1141-
stride=None,
1142-
padding=None,
1142+
stride=1,
1143+
padding=0,
11431144
groups=None,
11441145
param_attr=None,
11451146
bias_attr=None,
@@ -1252,12 +1253,10 @@ def conv2d(input,
12521253
raise ValueError("num_channels must be divisible by groups.")
12531254
num_filter_channels = num_channels / groups
12541255

1255-
if isinstance(filter_size, int):
1256-
filter_size = [filter_size, filter_size]
1257-
if isinstance(stride, int):
1258-
stride = [stride, stride]
1259-
if isinstance(padding, int):
1260-
padding = [padding, padding]
1256+
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
1257+
stride = utils.convert_to_list(stride, 2, 'stride')
1258+
padding = utils.convert_to_list(padding, 2, 'padding')
1259+
12611260
if not isinstance(use_cudnn, bool):
12621261
raise ValueError("use_cudnn should be True or False")
12631262

@@ -1432,31 +1431,31 @@ def sequence_last_step(input):
14321431

14331432

14341433
def pool2d(input,
1435-
pool_size,
1436-
pool_type,
1437-
pool_stride=None,
1438-
pool_padding=None,
1434+
pool_size=-1,
1435+
pool_type="max",
1436+
pool_stride=1,
1437+
pool_padding=0,
14391438
global_pooling=False,
14401439
use_cudnn=True,
14411440
name=None):
14421441
"""
14431442
This function adds the operator for pooling in 2 dimensions, using the
14441443
pooling configurations mentioned in input parameters.
14451444
"""
1446-
if pool_padding is None:
1447-
pool_padding = [0, 0]
1448-
if pool_stride is None:
1449-
pool_stride = [1, 1]
14501445
if pool_type not in ["max", "avg"]:
14511446
raise ValueError(
14521447
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
14531448
str(pool_type))
1454-
if isinstance(pool_size, int):
1455-
pool_size = [pool_size, pool_size]
1456-
if isinstance(pool_stride, int):
1457-
pool_stride = [pool_stride, pool_stride]
1458-
if isinstance(pool_padding, int):
1459-
pool_padding = [pool_padding, pool_padding]
1449+
1450+
if global_pooling is False and pool_size == -1:
1451+
raise ValueError(
1452+
"When the global_pooling is False, pool_size must be passed "
1453+
"and be a valid value. Received pool_size: " + str(pool_size))
1454+
1455+
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
1456+
pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
1457+
pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
1458+
14601459
if not isinstance(use_cudnn, bool):
14611460
raise ValueError("use_cudnn should be True or False")
14621461

@@ -1685,9 +1684,9 @@ def conv2d_transpose(input,
16851684
num_filters,
16861685
output_size=None,
16871686
filter_size=None,
1688-
padding=None,
1689-
stride=None,
1690-
dilation=None,
1687+
padding=0,
1688+
stride=1,
1689+
dilation=1,
16911690
param_attr=None,
16921691
use_cudnn=True,
16931692
name=None):
@@ -1783,37 +1782,19 @@ def conv2d_transpose(input,
17831782
raise TypeError("Input of conv2d_transpose must be Variable")
17841783
input_channel = input.shape[1]
17851784

1786-
op_attr = dict()
1787-
1788-
if isinstance(padding, int):
1789-
op_attr['paddings'] = [padding, padding]
1790-
elif padding is not None:
1791-
op_attr['paddings'] = padding
1792-
1793-
if isinstance(stride, int):
1794-
op_attr['strides'] = [stride, stride]
1795-
elif stride is not None:
1796-
op_attr['strides'] = stride
1797-
1798-
if isinstance(dilation, int):
1799-
op_attr['dilations'] = [dilation, dilation]
1800-
elif dilation is not None:
1801-
op_attr['dilations'] = dilation
1785+
padding = utils.convert_to_list(padding, 2, 'padding')
1786+
stride = utils.convert_to_list(stride, 2, 'stride')
1787+
dilation = utils.convert_to_list(dilation, 2, 'dilation')
18021788

18031789
if not isinstance(use_cudnn, bool):
18041790
raise ValueError("use_cudnn should be True or False")
1805-
op_attr['use_cudnn'] = use_cudnn
18061791

18071792
if filter_size is None:
18081793
if output_size is None:
18091794
raise ValueError("output_size must be set when filter_size is None")
18101795
if isinstance(output_size, int):
18111796
output_size = [output_size, output_size]
18121797

1813-
padding = op_attr.get('paddings', [0, 0])
1814-
stride = op_attr.get('strides', [1, 1])
1815-
dilation = op_attr.get('dilations', [1, 1])
1816-
18171798
h_in = input.shape[2]
18181799
w_in = input.shape[3]
18191800

@@ -1822,9 +1803,9 @@ def conv2d_transpose(input,
18221803
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
18231804
padding[1] - 1) / dilation[1] + 1
18241805
filter_size = [filter_size_h, filter_size_w]
1825-
1826-
elif isinstance(filter_size, int):
1827-
filter_size = [filter_size, filter_size]
1806+
else:
1807+
filter_size = utils.convert_to_list(filter_size, 2,
1808+
'conv2d_transpose.filter_size')
18281809

18291810
filter_shape = [input_channel, num_filters] + filter_size
18301811
img_filter = helper.create_parameter(
@@ -1836,7 +1817,12 @@ def conv2d_transpose(input,
18361817
inputs={'Input': [input],
18371818
'Filter': [img_filter]},
18381819
outputs={'Output': out},
1839-
attrs=op_attr)
1820+
attrs={
1821+
'strides': stride,
1822+
'paddings': padding,
1823+
'dilations': dilation,
1824+
'use_cudnn': use_cudnn
1825+
})
18401826

18411827
return out
18421828

python/paddle/fluid/layers/utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
16+
17+
def convert_to_list(value, n, name, dtype=np.int):
18+
"""
19+
Converts a single numerical type or iterable of numerical
20+
types into an numerical type list.
21+
22+
Arguments:
23+
value: The value to validate and convert. Could an int, or any iterable
24+
of ints.
25+
n: The size of the list to be returned.
26+
name: The name of the argument being validated, e.g. "stride" or
27+
"filter_size". This is only used to format error messages.
28+
dtype: the numerical type of the element of the list to be returned.
29+
30+
Returns:
31+
A list of n dtypes.
32+
33+
Raises:
34+
ValueError: If something else than an int/long or iterable thereof was
35+
passed.
36+
"""
37+
if isinstance(value, dtype):
38+
return [value, ] * n
39+
else:
40+
try:
41+
value_list = list(value)
42+
except TypeError:
43+
raise ValueError("The " + name +
44+
"'s type must be list or tuple. Received: " + str(
45+
value))
46+
if len(value_list) != n:
47+
raise ValueError("The " + name + "'s length must be " + str(n) +
48+
". Received: " + str(value))
49+
for single_value in value_list:
50+
try:
51+
dtype(single_value)
52+
except (ValueError, TypeError):
53+
raise ValueError(
54+
"The " + name + "'s type must be a list or tuple of " + str(
55+
n) + " " + str(dtype) + " . Received: " + str(
56+
value) + " "
57+
"including element " + str(single_value) + " of type" + " "
58+
+ str(type(single_value)))
59+
return value_list

0 commit comments

Comments
 (0)