@@ -1121,21 +1121,23 @@ def __init__(self,
1121
1121
self ._num_trainers = num_trainers
1122
1122
self ._clip_norm = local_grad_clip_norm * (num_trainers ** - 0.5 )
1123
1123
1124
- self ._get_dgc_regularization_param ()
1124
+ self .regular_type , self .regular_coeff = self ._get_regularization_param (
1125
+ self .regularization )
1125
1126
1126
- def _get_dgc_regularization_param (self ):
1127
- self . regular_coeff = 0. 0
1128
- self . regular_type = 0
1127
+ def _get_regularization_param (self , regularization ):
1128
+ regular_type = 0
1129
+ regular_coeff = 0. 0
1129
1130
1130
- if self . regularization is not None :
1131
- self . regular_coeff = self . regularization ._regularization_coeff
1131
+ if regularization is not None :
1132
+ regular_coeff = regularization ._regularization_coeff
1132
1133
from .regularizer import L1Decay , L2Decay
1133
- if isinstance (self . regularization , L1Decay ):
1134
- self . regular_type = 1
1135
- elif isinstance (self . regularization , L2Decay ):
1136
- self . regular_type = 2
1134
+ if isinstance (regularization , L1Decay ):
1135
+ regular_type = 1
1136
+ elif isinstance (regularization , L2Decay ):
1137
+ regular_type = 2
1137
1138
else :
1138
1139
assert False , 'regularization must be None|L1Decay|L2Deacy'
1140
+ return regular_type , regular_coeff
1139
1141
1140
1142
def _is_use_dgc (self , param_var , grad_var ):
1141
1143
var_numel = abs (reduce (lambda x , y : x * y , param_var .shape ))
@@ -1336,6 +1338,13 @@ def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
1336
1338
block = framework .default_main_program ().global_block ()
1337
1339
op_maker = core .op_proto_and_checker_maker
1338
1340
1341
+ regular_type = self .regular_type
1342
+ regular_coeff = self .regular_coeff
1343
+ # The regularizer of the Parameters have higher priority
1344
+ if param_var .regularizer is not None :
1345
+ regular_type , regular_coeff = self ._get_regularization_param (
1346
+ param_var .regularizer )
1347
+
1339
1348
dgc_op = block .append_op (
1340
1349
type = "dgc" ,
1341
1350
inputs = {
@@ -1360,8 +1369,8 @@ def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
1360
1369
"use_nesterov" : self ._use_nesterov ,
1361
1370
"rampup_begin_step" : float (self ._rampup_begin_step ),
1362
1371
"rampup_step" : float (self ._rampup_step ),
1363
- "regular_coeff" : float (self . regular_coeff ),
1364
- "regular_type" : int (self . regular_type ),
1372
+ "regular_coeff" : float (regular_coeff ),
1373
+ "regular_type" : int (regular_type ),
1365
1374
},
1366
1375
stop_gradient = True )
1367
1376
0 commit comments