Skip to content

Commit 5d30142

Browse files
committed
follow comment from panxin
1 parent 470d671 commit 5d30142

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

python/paddle/fluid/layers/utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,30 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
1415

1516

16-
def convert_to_list(value, n, name):
17-
"""Converts a single integer or iterable of integers into an integer list.
17+
def convert_to_list(value, n, name, dtype=np.int):
18+
"""
19+
Converts a single numerical type or iterable of numerical
20+
types into an numerical type list.
1821
1922
Arguments:
2023
value: The value to validate and convert. Could an int, or any iterable
2124
of ints.
2225
n: The size of the list to be returned.
2326
name: The name of the argument being validated, e.g. "stride" or
2427
"filter_size". This is only used to format error messages.
28+
dtype: the numerical type of the element of the list to be returned.
2529
2630
Returns:
27-
A list of n integers.
31+
A list of n dtypes.
2832
2933
Raises:
3034
ValueError: If something else than an int/long or iterable thereof was
3135
passed.
3236
"""
33-
if isinstance(value, int):
37+
if isinstance(value, dtype):
3438
return [value, ] * n
3539
else:
3640
try:
@@ -44,11 +48,12 @@ def convert_to_list(value, n, name):
4448
". Received: " + str(value))
4549
for single_value in value_list:
4650
try:
47-
int(single_value)
51+
dtype(single_value)
4852
except (ValueError, TypeError):
4953
raise ValueError(
5054
"The " + name + "'s type must be a list or tuple of " + str(
51-
n) + " integers. Received: " + str(value) + " "
55+
n) + " " + str(dtype) + " . Received: " + str(
56+
value) + " "
5257
"including element " + str(single_value) + " of type" + " "
5358
+ str(type(single_value)))
5459
return value_list

0 commit comments

Comments
 (0)