11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import numpy as np
14
15
15
16
16
- def convert_to_list (value , n , name ):
17
- """Converts a single integer or iterable of integers into an integer list.
17
+ def convert_to_list (value , n , name , dtype = np .int ):
18
+ """
19
+ Converts a single numerical type or iterable of numerical
20
+ types into an numerical type list.
18
21
19
22
Arguments:
20
23
value: The value to validate and convert. Could an int, or any iterable
21
24
of ints.
22
25
n: The size of the list to be returned.
23
26
name: The name of the argument being validated, e.g. "stride" or
24
27
"filter_size". This is only used to format error messages.
28
+ dtype: the numerical type of the element of the list to be returned.
25
29
26
30
Returns:
27
- A list of n integers .
31
+ A list of n dtypes .
28
32
29
33
Raises:
30
34
ValueError: If something else than an int/long or iterable thereof was
31
35
passed.
32
36
"""
33
- if isinstance (value , int ):
37
+ if isinstance (value , dtype ):
34
38
return [value , ] * n
35
39
else :
36
40
try :
@@ -44,11 +48,12 @@ def convert_to_list(value, n, name):
44
48
". Received: " + str (value ))
45
49
for single_value in value_list :
46
50
try :
47
- int (single_value )
51
+ dtype (single_value )
48
52
except (ValueError , TypeError ):
49
53
raise ValueError (
50
54
"The " + name + "'s type must be a list or tuple of " + str (
51
- n ) + " integers. Received: " + str (value ) + " "
55
+ n ) + " " + str (dtype ) + " . Received: " + str (
56
+ value ) + " "
52
57
"including element " + str (single_value ) + " of type" + " "
53
58
+ str (type (single_value )))
54
59
return value_list
0 commit comments