@@ -19,11 +19,11 @@
 
 
 def convert_and_compare(x, Type):
-    """
-    Convert x to be the same type as Type and then convert back to
-    check whether there is a loss of information
-    :param x: object to be checked
-    :param Type: target type to check x over
-
+    """
+    Convert x to be the same type as Type and then convert back to
+    check whether there is a loss of information
+    :param x: object to be checked
+    :param Type: target type to check x over
+
     """
     return type(x)(Type(x)) == x
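The round-trip check above is easy to sanity-check; a minimal sketch of the expected results, with values chosen purely for illustration:

    convert_and_compare(1, float)    # True:  int(float(1)) == 1, nothing lost
    convert_and_compare(1.5, int)    # False: float(int(1.5)) == 1.0 != 1.5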
@@ -30,23 +30,23 @@
 
 
 def is_compatible_with(x, Type):
-    """
-    Check if x has a type compatible with Type
-    :param x: object to be checked
-    :param Type: target type to check x over
-
+    """
+    Check if x has a type compatible with Type
+    :param x: object to be checked
+    :param Type: target type to check x over
+
     """
     if type(x) == Type:
         return True
     try:
         if float == Type or int == Type:
-            # avoid those types that can be converted to float/int but not very
-            # meaningful and could potentially lead to error
-            # i.e., str and bool typed value should not be used for initializing float/int variable
+            # avoid types that can be converted to float/int but where the
+            # conversion is not meaningful and could potentially lead to error,
+            # i.e., str and bool values should not initialize a float/int variable
             if not isinstance(x, str) and not isinstance(x, bool):
                 return convert_and_compare(x, Type)
         elif bool == Type:
-            # should not use string type to initialize bool variable
+            # a string should not be used to initialize a bool variable
             if not isinstance(x, str):
                 return convert_and_compare(x, Type)
         else:
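Together with the explicit str/bool guards, the helper only approves conversions that round-trip losslessly; a minimal sketch of the expected results (paths that fall through the guards return a falsy value rather than an explicit False):

    is_compatible_with(2, float)      # True:  2 -> 2.0 -> 2 round-trips cleanly
    is_compatible_with(1.5, int)      # False: truncating to 1 would lose 0.5
    is_compatible_with('1', int)      # falsy: str is rejected for float/int
    is_compatible_with(True, float)   # falsy: bool is rejected for float/int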
@@ -88,6 +88,10 @@ class ParameterAttribute(object):
     :type learning_rate: float or None
     :param momentum: The parameter momentum. None means use global value.
     :type momentum: float or None
+    :param gradient_clipping_threshold: Gradient clipping threshold. If a
+                                        gradient value is larger than this
+                                        threshold, it will be clipped.
+    :type gradient_clipping_threshold: float
     :param sparse_update: Enable sparse update for this parameter. It will
                           enable both local and remote sparse update.
     :type sparse_update: bool
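Given the field documented above, enabling clipping is a single extra keyword argument; a minimal usage sketch, assuming ParameterAttribute is imported from the module this diff patches (the 0.1 and 10.0 values are arbitrary illustrations):

    # clip this parameter's gradient values at 10.0
    attr = ParameterAttribute(learning_rate=0.1,
                              gradient_clipping_threshold=10.0)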
@@ -104,6 +108,7 @@ def __init__(self,
                  l2_rate=None,
                  learning_rate=None,
                  momentum=None,
+                 gradient_clipping_threshold=None,
                  sparse_update=False):
         # initialize strategy.
         if is_static:
@@ -152,6 +157,11 @@ def __init__(self,
             self.attr['sparse_update'] = True
             self.attr['sparse_remote_update'] = True
 
+        if gradient_clipping_threshold is not None and \
+                is_compatible_with(gradient_clipping_threshold, float):
+            self.attr['gradient_clipping_threshold'] = \
+                gradient_clipping_threshold
+
     def set_default_parameter_name(self, name):
         """
         Set default parameter name. If parameter not set, then will use default
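Since the assignment is guarded by is_compatible_with, an incompatible threshold is dropped silently instead of raising; a sketch of the resulting behavior, using the attr dictionary as it appears in the diff:

    attr = ParameterAttribute(gradient_clipping_threshold=10)
    'gradient_clipping_threshold' in attr.attr    # True: int 10 converts to float losslessly

    attr = ParameterAttribute(gradient_clipping_threshold='10')
    'gradient_clipping_threshold' in attr.attr    # False: the str value is silently ignored

Ignoring a bad value keeps old configs running, but it also means a mistyped threshold disables clipping without any warning.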