Skip to content

Commit 7a92e75

Browse files
authored
Fix dgc param regularizer, test=release/1.7 (#22903)
1 parent d9fe528 commit 7a92e75

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,21 +1121,23 @@ def __init__(self,
11211121
self._num_trainers = num_trainers
11221122
self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5)
11231123

1124-
self._get_dgc_regularization_param()
1124+
self.regular_type, self.regular_coeff = self._get_regularization_param(
1125+
self.regularization)
11251126

1126-
def _get_dgc_regularization_param(self):
1127-
self.regular_coeff = 0.0
1128-
self.regular_type = 0
1127+
def _get_regularization_param(self, regularization):
1128+
regular_type = 0
1129+
regular_coeff = 0.0
11291130

1130-
if self.regularization is not None:
1131-
self.regular_coeff = self.regularization._regularization_coeff
1131+
if regularization is not None:
1132+
regular_coeff = regularization._regularization_coeff
11321133
from .regularizer import L1Decay, L2Decay
1133-
if isinstance(self.regularization, L1Decay):
1134-
self.regular_type = 1
1135-
elif isinstance(self.regularization, L2Decay):
1136-
self.regular_type = 2
1134+
if isinstance(regularization, L1Decay):
1135+
regular_type = 1
1136+
elif isinstance(regularization, L2Decay):
1137+
regular_type = 2
11371138
else:
11381139
assert False, 'regularization must be None|L1Decay|L2Deacy'
1140+
return regular_type, regular_coeff
11391141

11401142
def _is_use_dgc(self, param_var, grad_var):
11411143
var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
@@ -1336,6 +1338,13 @@ def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
13361338
block = framework.default_main_program().global_block()
13371339
op_maker = core.op_proto_and_checker_maker
13381340

1341+
regular_type = self.regular_type
1342+
regular_coeff = self.regular_coeff
1343+
# The regularizer of the Parameters have higher priority
1344+
if param_var.regularizer is not None:
1345+
regular_type, regular_coeff = self._get_regularization_param(
1346+
param_var.regularizer)
1347+
13391348
dgc_op = block.append_op(
13401349
type="dgc",
13411350
inputs={
@@ -1360,8 +1369,8 @@ def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
13601369
"use_nesterov": self._use_nesterov,
13611370
"rampup_begin_step": float(self._rampup_begin_step),
13621371
"rampup_step": float(self._rampup_step),
1363-
"regular_coeff": float(self.regular_coeff),
1364-
"regular_type": int(self.regular_type),
1372+
"regular_coeff": float(regular_coeff),
1373+
"regular_type": int(regular_type),
13651374
},
13661375
stop_gradient=True)
13671376

python/paddle/fluid/tests/unittests/test_dgc_optimizer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def check_dgc_momentum_optimizer(self,
4444
shape=[dims[0], dims[1]],
4545
lod_level=0,
4646
name="mul.x",
47-
optimize_attr={'learning_rate': 1.1})
47+
optimize_attr={'learning_rate': 1.1},
48+
regularizer=None if regularization is not None else
49+
regularizer.L2DecayRegularizer(2e-4))
4850
mul_y = block.create_var(
4951
dtype="float32",
5052
shape=[dims[1], dims[2]],
@@ -102,6 +104,14 @@ def check_dgc_momentum_optimizer(self,
102104
self.assertEqual(init_ops[0].type, "fill_constant")
103105
self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
104106

107+
# check dgc op regularization coeff
108+
train_ops = program.global_block().ops
109+
for op in train_ops:
110+
if op.type == "dgc":
111+
coeff = 2e-4 if regularization is None else 1e-4
112+
self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
113+
print("dgc regular_coeff=" + str(coeff))
114+
105115
with open("test_dgc_optimizer_" + name + ".log", "w") as f:
106116
program_to_code(program, fout=f)
107117

@@ -116,6 +126,10 @@ def test_momentum_with_dgc(self):
116126
name="dgc_momentum",
117127
regularization=regularizer.L2Decay(1e-4))
118128

129+
# check param.regularizer in dgc
130+
self.check_dgc_momentum_optimizer(
131+
dims=[16, 1024, 8], name="dgc_momentum")
132+
119133

120134
if __name__ == '__main__':
121135
unittest.main()

0 commit comments

Comments (0)