@@ -62,10 +62,10 @@ def noam_decay(d_model, warmup_steps):
         The decayed learning rate.
     """
     global_step = _decay_step_counter(1)
-    with init_on_cpu():
-        a = global_step**-0.5
-        b = (warmup_steps**-1.5) * global_step
-        lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
+
+    a = global_step**-0.5
+    b = (warmup_steps**-1.5) * global_step
+    lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)

     return lr_value

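Note: the deleted `init_on_cpu()` scope only pinned the computation to CPU; the arithmetic is unchanged. For reference, this hunk builds the Noam schedule. A minimal scalar sketch of the same formula in plain Python (the function name and signature here are illustrative, not part of the fluid API):

    def noam_lr(d_model, warmup_steps, step):
        # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
        a = step**-0.5
        b = (warmup_steps**-1.5) * step
        return (d_model**-0.5) * min(a, b)

Because the counter starts at 1 (`_decay_step_counter(1)`), `step**-0.5` never divides by zero.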
@@ -108,12 +108,10 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """
     global_step = _decay_step_counter()

-    with init_on_cpu():
-        # update learning_rate
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
-        decayed_lr = learning_rate * (decay_rate**div_res)
+    div_res = global_step / decay_steps
+    if staircase:
+        div_res = ops.floor(div_res)
+    decayed_lr = learning_rate * (decay_rate**div_res)

     return decayed_lr

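Same mechanical change: the body is dedented out of the removed `init_on_cpu()` block. Sketched with scalar floats instead of graph variables (illustrative names only):

    import math

    def exponential_lr(learning_rate, decay_steps, decay_rate, step, staircase=False):
        div_res = step / decay_steps
        if staircase:
            # hold the rate constant within each decay_steps-wide interval
            div_res = math.floor(div_res)
        return learning_rate * decay_rate**div_res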
@@ -138,11 +136,10 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """
     global_step = _decay_step_counter()

-    with init_on_cpu():
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
-        decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
+    div_res = global_step / decay_steps
+    if staircase:
+        div_res = ops.floor(div_res)
+    decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)

     return decayed_lr

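`natural_exp_decay` differs from `exponential_decay` only in applying the decay through an exponential. A scalar sketch of the same arithmetic (illustrative, not the fluid API):

    import math

    def natural_exp_lr(learning_rate, decay_steps, decay_rate, step, staircase=False):
        div_res = step / decay_steps
        if staircase:
            div_res = math.floor(div_res)
        # lr * e^(-decay_rate * step / decay_steps)
        return learning_rate * math.exp(-decay_rate * div_res)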
@@ -184,12 +181,11 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """
     global_step = _decay_step_counter()

-    with init_on_cpu():
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
+    div_res = global_step / decay_steps
+    if staircase:
+        div_res = ops.floor(div_res)

-        decayed_lr = learning_rate / (1 + decay_rate * div_res)
+    decayed_lr = learning_rate / (1 + decay_rate * div_res)

     return decayed_lr

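The inverse-time (hyperbolic) schedule, again as a scalar sketch under the same illustrative naming:

    import math

    def inverse_time_lr(learning_rate, decay_steps, decay_rate, step, staircase=False):
        div_res = step / decay_steps
        if staircase:
            div_res = math.floor(div_res)
        # lr / (1 + decay_rate * step / decay_steps)
        return learning_rate / (1 + decay_rate * div_res)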
@@ -224,25 +220,22 @@ def polynomial_decay(learning_rate,
     """
     global_step = _decay_step_counter()

-    with init_on_cpu():
-        if cycle:
-            div_res = ops.ceil(global_step / decay_steps)
-            zero_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=0.0)
-            one_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=1.0)
-
-            with control_flow.Switch() as switch:
-                with switch.case(global_step == zero_var):
-                    tensor.assign(input=one_var, output=div_res)
-            decay_steps = decay_steps * div_res
-        else:
-            decay_steps_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(decay_steps))
-            global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
-
-        decayed_lr = (learning_rate - end_learning_rate) * \
-            ((1 - global_step / decay_steps)**power) + end_learning_rate
+    if cycle:
+        div_res = ops.ceil(global_step / decay_steps)
+        zero_var = tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
+        one_var = tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
+
+        with control_flow.Switch() as switch:
+            with switch.case(global_step == zero_var):
+                tensor.assign(input=one_var, output=div_res)
+        decay_steps = decay_steps * div_res
+    else:
+        decay_steps_var = tensor.fill_constant(
+            shape=[1], dtype='float32', value=float(decay_steps))
+        global_step = ops.elementwise_min(x=global_step, y=decay_steps_var)
+
+    decayed_lr = (learning_rate - end_learning_rate) * \
+        ((1 - global_step / decay_steps)**power) + end_learning_rate

     return decayed_lr

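The non-obvious part of this hunk is the `cycle` branch: the `Switch` maps step 0 onto the first cycle, and each later cycle stretches `decay_steps`. A scalar sketch of the whole schedule (illustrative names, not the fluid graph version):

    import math

    def polynomial_lr(learning_rate, decay_steps, end_learning_rate, power,
                      step, cycle=False):
        if cycle:
            # restart the decay every decay_steps steps; step == 0 counts as
            # the first cycle (div_res = 1), mirroring the Switch above
            div_res = max(1.0, math.ceil(step / decay_steps))
            decay_steps = decay_steps * div_res
        else:
            # clamp so the rate bottoms out at end_learning_rate
            step = min(step, decay_steps)
        return (learning_rate - end_learning_rate) * \
            (1 - step / decay_steps)**power + end_learning_rate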