@@ -47,6 +47,8 @@ def __init__(self, learning_rate, regularization=None):
            raise TypeError("learning rate should be float or Variable")
        self.regularization = regularization
        self._learning_rate = learning_rate
+       # the learning rate dtype should be inferred from the loss
+       self._dtype = None
        # each program should have an independent learning rate
        # program -> Variable(learning_rate)
        self._learning_rate_map = dict()
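The pattern introduced in this hunk recurs throughout the diff: `_dtype` starts as `None` in `__init__`, and every site that previously hard-coded `'float32'` now uses it as soon as a loss dtype has been seen. A minimal, framework-free sketch of that lazy-fallback behaviour (the class and method names below are illustrative stand-ins, not part of the Paddle API):

```python
# Minimal sketch of the lazy dtype fallback this commit introduces.
# Plain Python, no Paddle dependency; DtypeAwareOptimizer is hypothetical.

class DtypeAwareOptimizer(object):
    def __init__(self):
        # inferred later from the loss; stays None until an
        # optimization pass has actually seen a loss variable
        self._dtype = None

    def _resolved_dtype(self):
        # the expression repeated at every variable-creation site in the diff
        return 'float32' if self._dtype is None else self._dtype

    def minimize(self, loss_dtype):
        # stands in for `self._dtype = loss.dtype` in create_optimization_pass
        self._dtype = loss_dtype
        return self._resolved_dtype()


opt = DtypeAwareOptimizer()
print(opt._resolved_dtype())    # 'float32' before any loss is seen
print(opt.minimize('float64'))  # 'float64' once the loss dtype is known
```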
@@ -77,7 +79,7 @@ def _create_global_learning_rate(self):
                name=unique_name.generate("learning_rate"),
                shape=[1],
                value=float(self._learning_rate),
-               dtype='float32',
+               dtype='float32' if self._dtype == None else self._dtype,
                persistable=True)

    def global_learning_rate(self, program=None):
@@ -200,6 +202,7 @@ def create_optimization_pass(self,

        # Create any accumulators
        program = loss.block.program
+       self._dtype = loss.dtype
        with program_guard(program, startup_program):
            global_block = framework.default_main_program().global_block()
            start = len(global_block.ops)
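This assignment is the only place in the diff where `_dtype` is actually populated: once `create_optimization_pass` has run against a loss, the `'float32' if self._dtype == None else self._dtype` expressions in the other hunks all resolve to the loss dtype, so the global learning rate and the optimizer accumulators are created in the same precision as the loss (e.g. float64) instead of being pinned to float32.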
@@ -391,7 +394,7 @@ def _create_accumulators(self, block, parameters):
        beta_shape = [1]
        self._beta1_pow_acc = self.helper.create_global_variable(
            name=unique_name.generate('beta1_pow_acc'),
-           dtype='float32',
+           dtype='float32' if self._dtype == None else self._dtype,
            shape=beta_shape,
            lod_level=0,
            persistable=True)
@@ -400,7 +403,7 @@ def _create_accumulators(self, block, parameters):

        self._beta2_pow_acc = self.helper.create_global_variable(
            name=unique_name.generate('beta2_pow_acc'),
-           dtype='float32',
+           dtype='float32' if self._dtype == None else self._dtype,
            shape=beta_shape,
            lod_level=0,
            persistable=True)
@@ -493,7 +496,7 @@ def _create_accumulators(self, block, parameters):
        beta_shape = [1]
        self._beta1_pow_acc = self.helper.create_global_variable(
            name=unique_name.generate('beta1_pow_acc'),
-           dtype='float32',
+           dtype='float32' if self._dtype == None else self._dtype,
            shape=beta_shape,
            lod_level=0,
            persistable=True)
@@ -900,8 +903,10 @@ def _add_average_apply_op(self, block, param_grad):
        # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates)
        tmp = layers.sum(x=[num_accumulates, old_num_accumulates])
        sum = layers.sum(x=[sum_1, sum_2, sum_3])
-       tmp = layers.cast(x=tmp, dtype='float32')
-       sum = layers.cast(x=sum, dtype='float32')
+       tmp = layers.cast(
+           x=tmp, dtype='float32' if self._dtype == None else self._dtype)
+       sum = layers.cast(
+           x=sum, dtype='float32' if self._dtype == None else self._dtype)
        layers.elementwise_div(x=sum, y=tmp, out=param)

    def _add_average_restore_op(self, block, param_grad):
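The averaging pass divides the running parameter sums by the (integer) accumulate counters, so both operands are cast to a common dtype first; after this change that dtype follows the loss rather than being fixed at float32. A hedged illustration of the same arithmetic with NumPy standing in for the `layers.*` ops (the function and argument names below are illustrative only):

```python
# NumPy stand-in for the cast-then-divide step in _add_average_apply_op;
# not the Paddle implementation, just the same arithmetic.
import numpy as np

def apply_average(sum_1, sum_2, sum_3, num_acc, old_num_acc, dtype='float64'):
    # layers.cast(x=tmp, dtype=...): total accumulate count, cast from int
    tmp = np.asarray(num_acc + old_num_acc, dtype=dtype)
    # layers.cast(x=sum, dtype=...): combined parameter sums in the same dtype
    total = np.asarray(sum_1 + sum_2 + sum_3, dtype=dtype)
    # layers.elementwise_div(x=sum, y=tmp, out=param)
    return total / tmp

param = apply_average(np.ones(3), np.zeros(3), np.zeros(3),
                      num_acc=4, old_num_acc=6)
print(param.dtype, param)  # float64 [0.1 0.1 0.1]
```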