@@ -112,58 +112,52 @@ def create_operators(self, param, grad):
 
 
 class GradientClipByGlobalNorm(BaseGradientClipAttr):
-    global_norm_var = None
-    local_norm_var = None
-    clip_norm_var = None
-    scale_var = None
-
-    @classmethod
-    def init(cls, clip_norm):
-        if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)):
-            raise TypeError("The 'clip_norm' must be a value of int or float")
-
-        cls.global_norm_var = layers.fill_constant(
-            shape=[1], dtype="float32", value=0.0)
-        cls.local_norm_var = layers.create_tensor(dtype="float32")
-        cls.clip_norm_var = layers.fill_constant(
-            shape=[1], dtype="float32", value=clip_norm)
-
-    @classmethod
-    def check_init(cls):
-        if not (isinstance(cls.global_norm_var, framework.Variable) and
-                isinstance(cls.local_norm_var, framework.Variable) and
-                isinstance(cls.clip_norm_var, framework.Variable)):
-            raise ValueError(
-                "Class 'GradientClipByGlobalNorm' has not been properly initialized. \
-                Please call GradientClipByGlobalNorm.init() first.")
+    def __init__(self, clip_norm, group_name="default_group"):
+        if not isinstance(group_name, basestring):
+            raise TypeError("'group_name' must be a basestring.")
+
+        self.clip_norm = clip_norm
+        self.group_name = group_name
 
     def process_context(self, context, param, grad):
-        cls = self.__class__
-        cls.check_init()
+        if self.group_name not in context:
+            context[self.group_name] = []
+            context[self.group_name + "_clip_value"] = self.clip_norm
+            context[self.group_name + "_clip"] = layers.fill_constant(
+                shape=[1], dtype="float32", value=self.clip_norm)
+        else:
+            if not self.clip_norm == context[self.group_name + "_clip_value"]:
+                raise ValueError(
+                    "All parameters' 'clip_norm' of a same group should be the same"
+                )
 
-        cls.local_norm_var = layers.reduce_sum(
-            input=layers.pow(x=grad, factor=2.0))
-        layers.sums(
-            input=[cls.local_norm_var, cls.global_norm_var],
-            out=[cls.global_norm_var])
+        local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
+        context[self.group_name].append(local_norm_var)
 
-    def create_operators(self, param, grad):
-        cls = self.__class__
-        cls.check_init()
+        self.context = context
 
-        if cls.scale_var is None:
-            layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var)
-            cls.scale_var = layers.elementwise_div(
-                x=cls.clip_norm_var,
+    def create_operators(self, param, grad):
+        group_scale_name = self.group_name + "_scale"
+        if group_scale_name not in self.context:
+            group_norm_var = layers.sums(input=self.context[self.group_name])
+            layers.sqrt(x=group_norm_var, out=group_norm_var)
+            clip_var = self.context[self.group_name + "_clip"]
+            group_scale_var = layers.elementwise_div(
+                x=clip_var,
                 y=layers.elementwise_max(
-                    x=cls.clip_norm_var, y=cls.global_norm_var))
-            assert cls.scale_var.shape == (1L, )
+                    x=clip_var, y=group_norm_var))
+            assert group_scale_var.shape == (1L, )
+            self.context[group_scale_name] = group_scale_var
 
-        new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var)
+        new_grad = layers.elementwise_mul(
+            x=grad, y=self.context[group_scale_name])
         return param, new_grad
 
 
-def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
+def gradient_clip_by_global_norm(clip_norm,
+                                 param_list=None,
+                                 group_name="default_group",
+                                 program=None):
     if program is None:
         program = framework.default_main_program()
     if param_list is None:
@@ -175,9 +169,9 @@ def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
             "'param_list' should be a list of Parameter or basestring(parameter's name)."
         )
 
-    GradientClipByGlobalNorm.init(clip_norm)
     for param in param_list:
-        param.gradient_clip_attr = GradientClipByGlobalNorm()
+        param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
+                                                            group_name)
 
 
 def append_gradient_clip_ops(param_grad):
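
Restating the computation in the new process_context / create_operators pair for reviewers: every gradient in a group contributes the sum of its squared entries, the group total is square-rooted once, and all gradients in the group are multiplied by a single shared scale. In formula form (a restatement of the diff above, not an addition to it):

\[
\|g\|_{\mathrm{group}} = \sqrt{\sum_{i \in \mathrm{group}} \sum \mathrm{grad}_i^2}, \qquad
\mathrm{scale} = \frac{\mathrm{clip\_norm}}{\max\left(\mathrm{clip\_norm},\; \|g\|_{\mathrm{group}}\right)}, \qquad
\mathrm{grad}_i \leftarrow \mathrm{scale} \cdot \mathrm{grad}_i .
\]

So gradients are left unchanged while the group norm is at or below clip_norm, and are rescaled so that the group norm equals clip_norm otherwise.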
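A minimal usage sketch of the new interface follows. It is illustrative only: the toy network, the fluid.layers / fluid.optimizer calls, the module path for clip, and the point at which append_gradient_clip_ops runs are assumptions about the surrounding fluid API of this era; only gradient_clip_by_global_norm and its group_name argument come from this patch.

    # Illustrative sketch (Python 2, matching the basestring/1L style of this file).
    import paddle.v2.fluid as fluid
    import paddle.v2.fluid.clip as clip  # assumed location of this module

    x = fluid.layers.data(name="x", shape=[13], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="float32")
    y_predict = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # Attach a GradientClipByGlobalNorm attr to every parameter created so far.
    # Parameters sharing the same group_name are rescaled by one common factor
    # computed from their joint global norm.
    clip.gradient_clip_by_global_norm(clip_norm=1.0, group_name="group_1")

    sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    # Clip ops are assumed to be appended when the optimizer processes gradients.
    sgd.minimize(avg_cost)

Parameters registered under different group names each get their own group norm and scale; mixing clip_norm values within one group raises the ValueError added above.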