@@ -62,10 +62,10 @@ def noam_decay(d_model, warmup_steps):
62
62
The decayed learning rate.
63
63
"""
64
64
global_step = _decay_step_counter (1 )
65
- with init_on_cpu ():
66
- a = global_step ** - 0.5
67
- b = (warmup_steps ** - 1.5 ) * global_step
68
- lr_value = (d_model ** - 0.5 ) * ops .elementwise_min (a , b )
65
+
66
+ a = global_step ** - 0.5
67
+ b = (warmup_steps ** - 1.5 ) * global_step
68
+ lr_value = (d_model ** - 0.5 ) * ops .elementwise_min (a , b )
69
69
70
70
return lr_value
71
71
@@ -108,12 +108,10 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
108
108
"""
109
109
global_step = _decay_step_counter ()
110
110
111
- with init_on_cpu ():
112
- # update learning_rate
113
- div_res = global_step / decay_steps
114
- if staircase :
115
- div_res = ops .floor (div_res )
116
- decayed_lr = learning_rate * (decay_rate ** div_res )
111
+ div_res = global_step / decay_steps
112
+ if staircase :
113
+ div_res = ops .floor (div_res )
114
+ decayed_lr = learning_rate * (decay_rate ** div_res )
117
115
118
116
return decayed_lr
119
117
@@ -138,11 +136,10 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
138
136
"""
139
137
global_step = _decay_step_counter ()
140
138
141
- with init_on_cpu ():
142
- div_res = global_step / decay_steps
143
- if staircase :
144
- div_res = ops .floor (div_res )
145
- decayed_lr = learning_rate * ops .exp (- 1 * decay_rate * div_res )
139
+ div_res = global_step / decay_steps
140
+ if staircase :
141
+ div_res = ops .floor (div_res )
142
+ decayed_lr = learning_rate * ops .exp (- 1 * decay_rate * div_res )
146
143
147
144
return decayed_lr
148
145
@@ -184,12 +181,11 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
184
181
"""
185
182
global_step = _decay_step_counter ()
186
183
187
- with init_on_cpu ():
188
- div_res = global_step / decay_steps
189
- if staircase :
190
- div_res = ops .floor (div_res )
184
+ div_res = global_step / decay_steps
185
+ if staircase :
186
+ div_res = ops .floor (div_res )
191
187
192
- decayed_lr = learning_rate / (1 + decay_rate * div_res )
188
+ decayed_lr = learning_rate / (1 + decay_rate * div_res )
193
189
194
190
return decayed_lr
195
191
@@ -224,25 +220,22 @@ def polynomial_decay(learning_rate,
224
220
"""
225
221
global_step = _decay_step_counter ()
226
222
227
- with init_on_cpu ():
228
- if cycle :
229
- div_res = ops .ceil (global_step / decay_steps )
230
- zero_var = tensor .fill_constant (
231
- shape = [1 ], dtype = 'float32' , value = 0.0 )
232
- one_var = tensor .fill_constant (
233
- shape = [1 ], dtype = 'float32' , value = 1.0 )
234
-
235
- with control_flow .Switch () as switch :
236
- with switch .case (global_step == zero_var ):
237
- tensor .assign (input = one_var , output = div_res )
238
- decay_steps = decay_steps * div_res
239
- else :
240
- decay_steps_var = tensor .fill_constant (
241
- shape = [1 ], dtype = 'float32' , value = float (decay_steps ))
242
- global_step = ops .elementwise_min (x = global_step , y = decay_steps_var )
223
+ if cycle :
224
+ div_res = ops .ceil (global_step / decay_steps )
225
+ zero_var = tensor .fill_constant (shape = [1 ], dtype = 'float32' , value = 0.0 )
226
+ one_var = tensor .fill_constant (shape = [1 ], dtype = 'float32' , value = 1.0 )
243
227
244
- decayed_lr = (learning_rate - end_learning_rate ) * \
245
- ((1 - global_step / decay_steps ) ** power ) + end_learning_rate
228
+ with control_flow .Switch () as switch :
229
+ with switch .case (global_step == zero_var ):
230
+ tensor .assign (input = one_var , output = div_res )
231
+ decay_steps = decay_steps * div_res
232
+ else :
233
+ decay_steps_var = tensor .fill_constant (
234
+ shape = [1 ], dtype = 'float32' , value = float (decay_steps ))
235
+ global_step = ops .elementwise_min (x = global_step , y = decay_steps_var )
236
+
237
+ decayed_lr = (learning_rate - end_learning_rate ) * \
238
+ ((1 - global_step / decay_steps ) ** power ) + end_learning_rate
246
239
return decayed_lr
247
240
248
241
@@ -277,28 +270,28 @@ def piecewise_decay(boundaries, values):
277
270
278
271
global_step = _decay_step_counter ()
279
272
280
- with init_on_cpu ():
281
- lr = tensor .create_global_var (
282
- shape = [1 ],
283
- value = 0.0 ,
284
- dtype = 'float32' ,
285
- persistable = True ,
286
- name = "learning_rate" )
273
+ lr = tensor .create_global_var (
274
+ shape = [1 ],
275
+ value = 0.0 ,
276
+ dtype = 'float32' ,
277
+ persistable = True ,
278
+ name = "learning_rate" )
287
279
288
- with control_flow .Switch () as switch :
289
- for i in range (len (boundaries )):
290
- boundary_val = tensor .fill_constant (
291
- shape = [1 ], dtype = 'float32' , value = float (boundaries [i ]))
292
- value_var = tensor .fill_constant (
293
- shape = [1 ], dtype = 'float32' , value = float (values [i ]))
294
- with switch .case (global_step < boundary_val ):
295
- tensor .assign (value_var , lr )
296
- last_value_var = tensor .fill_constant (
280
+ with control_flow .Switch () as switch :
281
+ for i in range (len (boundaries )):
282
+ boundary_val = tensor .fill_constant (
297
283
shape = [1 ],
298
284
dtype = 'float32' ,
299
- value = float (values [len (values ) - 1 ]))
300
- with switch .default ():
301
- tensor .assign (last_value_var , lr )
285
+ value = float (boundaries [i ]),
286
+ force_cpu = True )
287
+ value_var = tensor .fill_constant (
288
+ shape = [1 ], dtype = 'float32' , value = float (values [i ]))
289
+ with switch .case (global_step < boundary_val ):
290
+ tensor .assign (value_var , lr )
291
+ last_value_var = tensor .fill_constant (
292
+ shape = [1 ], dtype = 'float32' , value = float (values [len (values ) - 1 ]))
293
+ with switch .default ():
294
+ tensor .assign (last_value_var , lr )
302
295
303
296
return lr
304
297
@@ -333,9 +326,9 @@ def _balanced_weight(param_norm, grad_norm):
333
326
grad_norm = ops .sqrt (nn .reduce_sum (input = ops .square (grad )))
334
327
if type (param_lr ) == float and param_lr == 1.0 :
335
328
decayed_lr = learning_rate * param_norm \
336
- / _balanced_weight (param_norm , grad_norm )
329
+ / _balanced_weight (param_norm , grad_norm )
337
330
else :
338
331
decayed_lr = learning_rate * param_lr * param_norm \
339
- / _balanced_weight (param_norm , grad_norm )
332
+ / _balanced_weight (param_norm , grad_norm )
340
333
# set back param local learning rate
341
334
param .optimize_attr ['learning_rate' ] = decayed_lr
0 commit comments