@@ -277,28 +277,28 @@ def piecewise_decay(boundaries, values):
 
     global_step = _decay_step_counter()
 
-    with init_on_cpu():
-        lr = tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-
-        with control_flow.Switch() as switch:
-            for i in range(len(boundaries)):
-                boundary_val = tensor.fill_constant(
-                    shape=[1], dtype='float32', value=float(boundaries[i]))
-                value_var = tensor.fill_constant(
-                    shape=[1], dtype='float32', value=float(values[i]))
-                with switch.case(global_step < boundary_val):
-                    tensor.assign(value_var, lr)
-            last_value_var = tensor.fill_constant(
+    lr = tensor.create_global_var(
+        shape=[1],
+        value=0.0,
+        dtype='float32',
+        persistable=True,
+        name="learning_rate")
+
+    with control_flow.Switch() as switch:
+        for i in range(len(boundaries)):
+            boundary_val = tensor.fill_constant(
                 shape=[1],
                 dtype='float32',
-                value=float(values[len(values) - 1]))
-            with switch.default():
-                tensor.assign(last_value_var, lr)
+                value=float(boundaries[i]),
+                force_cpu=True)
+            value_var = tensor.fill_constant(
+                shape=[1], dtype='float32', value=float(values[i]))
+            with switch.case(global_step < boundary_val):
+                tensor.assign(value_var, lr)
+        last_value_var = tensor.fill_constant(
+            shape=[1], dtype='float32', value=float(values[len(values) - 1]))
+        with switch.default():
+            tensor.assign(last_value_var, lr)
 
     return lr
 
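Note on the hunk above: it drops the init_on_cpu() guard around the whole block and instead pins only the boundary constant to CPU via force_cpu=True, which appears intended to keep the Switch comparison against the CPU-side step counter valid without forcing the entire learning-rate subgraph onto CPU. A minimal usage sketch, assuming the function is exported as fluid.layers.piecewise_decay (the boundaries/values below are made up for illustration):

import paddle.fluid as fluid

# Illustrative schedule: lr = 1.0 for step < 100,
# 0.5 for 100 <= step < 200, and 0.1 from step 200 onward.
boundaries = [100, 200]
values = [1.0, 0.5, 0.1]

lr = fluid.layers.piecewise_decay(boundaries=boundaries, values=values)
optimizer = fluid.optimizer.Momentum(learning_rate=lr, momentum=0.9)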
@@ -333,9 +333,9 @@ def _balanced_weight(param_norm, grad_norm):
         grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
         if type(param_lr) == float and param_lr == 1.0:
             decayed_lr = learning_rate * param_norm \
-                / _balanced_weight(param_norm, grad_norm)
+                         / _balanced_weight(param_norm, grad_norm)
         else:
             decayed_lr = learning_rate * param_lr * param_norm \
-                / _balanced_weight(param_norm, grad_norm)
+                         / _balanced_weight(param_norm, grad_norm)
         # set back param local learning rate
         param.optimize_attr['learning_rate'] = decayed_lr
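For reference, the lines touched in this hunk compute a LARS-style local learning rate, learning_rate * param_lr * param_norm / _balanced_weight(param_norm, grad_norm). A rough NumPy sketch of that scaling, assuming _balanced_weight returns grad_norm + weight_decay * param_norm as in the LARS formulation (the helper's actual definition lives earlier in this file and may differ):

import numpy as np

def lars_local_lr(learning_rate, param, grad, param_lr=1.0, weight_decay=0.0005):
    # Illustrative helper, not part of the file being patched.
    param_norm = np.sqrt(np.sum(np.square(param)))
    grad_norm = np.sqrt(np.sum(np.square(grad)))
    # Assumed stand-in for _balanced_weight(param_norm, grad_norm).
    balanced = grad_norm + weight_decay * param_norm
    if param_lr == 1.0:
        return learning_rate * param_norm / balanced
    return learning_rate * param_lr * param_norm / balanced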