@@ -85,12 +85,20 @@ def _append_optimize_op(self, block, param_and_grad):
8585 if getattr (param_and_grad [0 ], 'is_sparse_grad' , None ):
8686 index = getattr (param_and_grad [0 ], 'index' , None )
8787 axis = getattr (param_and_grad [0 ], 'axis' , None )
88- _ , _ = paddle ._C_ops .sparse_momentum (
89- param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr ,
90- param_and_grad [0 ], velocity_acc , 'mu' , self ._momentum ,
91- 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
92- self ._regularization_method , 'regularization_coeff' ,
93- self ._regularization_coeff , 'axis' , axis )
88+ try :
89+ _ , _ = paddle ._C_ops .sparse_momentum (
90+ param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr ,
91+ param_and_grad [0 ], velocity_acc , 'mu' , self ._momentum ,
92+ 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
93+ self ._regularization_method , 'regularization_coeff' ,
94+ self ._regularization_coeff , 'axis' , axis )
95+ except :
96+ _ , _ , _ = paddle ._C_ops .sparse_momentum (
97+ param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr , master_weight ,
98+ param_and_grad [0 ], velocity_acc , master_weight , 'mu' , self ._momentum ,
99+ 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
100+ self ._regularization_method , 'regularization_coeff' ,
101+ self ._regularization_coeff , 'axis' , axis , 'multi_precision' , find_master )
94102 else :
95103 _ , _ , _ = paddle ._C_ops .momentum (
96104 param_and_grad [0 ], param_and_grad [1 ], velocity_acc , lr ,
0 commit comments