
Commit b8a593e

Use correct master weights in AdamW. (#30895) (#31142)
* Use correct master weights in AdamW.
* Just modify the master weight.
* Update for CI coverage.
1 parent 37b7182 commit b8a593e

File tree

2 files changed: +12 -3 lines changed

python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
     test_program = train_program.clone(for_test=True)
 
     if use_adam:
-        optimizer = paddle.optimizer.Adam(
+        optimizer = paddle.optimizer.AdamW(
             learning_rate=0.001,
             epsilon=1e-8,
             weight_decay=0.0,
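
The test change above swaps Adam for AdamW so that the pure-FP16 training path also exercises AdamW's decoupled weight decay, which is exactly the code the adamw.py patch below modifies. As a quick reminder of the distinction (a NumPy-only sketch with made-up values, not Paddle code): Adam-style regularization folds the L2 penalty into the gradient before the adaptive update, while AdamW scales the parameter itself by 1 - lr * coeff.

import numpy as np

lr, wd = 0.001, 0.01    # learning rate and weight-decay coefficient
w = np.float32(1.0)     # a parameter
grad = np.float32(0.5)  # its gradient for this step

# Adam + L2 regularization: the penalty is added to the gradient and is then
# rescaled by Adam's moment-based update.
grad_with_l2 = grad + wd * w

# AdamW (decoupled): the parameter is scaled directly, independently of the
# gradient statistics -- this is the scaled_param / assign path patched below.
decay_coeff = 1.0 - lr * wd
w_decayed = w * decay_coeff

print("grad with L2:", grad_with_l2, "| decayed weight:", w_decayed)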

python/paddle/optimizer/adamw.py

Lines changed: 11 additions & 2 deletions
@@ -14,6 +14,7 @@
 
 from .optimizer import Optimizer
 from .adam import Adam
+from ..fluid import core
 from ..fluid import framework
 from ..fluid.dygraph import base as imperative_base
 import paddle
@@ -182,8 +183,16 @@ def _append_decoupled_weight_decay(self, block, param_and_grad):
                 decay_coeff = 1.0 - learning_rate * self._coeff
                 self._lr_to_coeff[learning_rate] = decay_coeff
 
-            scaled_param = param * decay_coeff
-            paddle.fluid.layers.assign(input=scaled_param, output=param)
+            find_master = (self._multi_precision and
+                           param.dtype == core.VarDesc.VarType.FP16)
+            if find_master:
+                master_weight = self._master_weights[param.name]
+                scaled_param = master_weight * decay_coeff
+                paddle.fluid.layers.assign(
+                    input=scaled_param, output=master_weight)
+            else:
+                scaled_param = param * decay_coeff
+                paddle.fluid.layers.assign(input=scaled_param, output=param)
 
     def _append_optimize_op(self, block, param_and_grad):
         self._append_decoupled_weight_decay(block, param_and_grad)
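
For context on why the patch targets the master weight: with multi-precision (pure FP16) training the optimizer keeps an FP32 master copy of each FP16 parameter, and the Adam update itself works on that master copy. Decaying only the FP16 parameter leaves the master weight untouched, so the decay is not reflected in, and can simply be overwritten by, whatever is later derived from the master copy. The sketch below is a minimal NumPy illustration of that failure mode, assuming the FP16 parameter is re-derived from its master weight each step; it is not Paddle's actual kernel code, and all names and values are illustrative.

import numpy as np

lr, coeff = 0.001, 0.01
decay_coeff = 1.0 - lr * coeff    # per-step decoupled decay factor

master = np.float32(0.12345678)   # FP32 master weight held by the optimizer
param = np.float16(master)        # FP16 copy used by forward/backward

# Buggy path: decay only the FP16 param. The master copy never sees the decay,
# so the next refresh from the master weight silently undoes it.
param = np.float16(param * np.float16(decay_coeff))
param = np.float16(master)        # optimizer step re-derives param from master
assert param == np.float16(np.float32(0.12345678))   # the decay has been lost

# Fixed path (what this commit does): decay the FP32 master weight, then derive
# the FP16 param from it as usual.
master = np.float32(master * decay_coeff)
param = np.float16(master)
print("decayed master:", master, "-> fp16 param:", param)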
