
Commit 9b64d63

Author: Yibing Liu
Fix the global_step & continuous applying error in EMA (#22090) (#22130)
* Fix the global_step & continuous applying error in EMA
* Fix for step 0 & add unit test

test=release/1.6
1 parent 70c073a commit 9b64d63
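
For context, the ops this commit touches implement a bias-corrected exponential moving average: ema_t = decay * ema_{t-1} + (1 - decay) * param_t, reported as ema_t / (1 - decay**t). Below is a minimal NumPy sketch of that math (illustrative only, not code from the commit). Note that at t = 0 the correction term 1 - decay**0 is exactly zero, which is why the diff below guards the division with a Switch on global_step > 0.

```python
# Minimal NumPy sketch of the bias-corrected EMA this commit is about
# (illustrative only; variable names are not from the Paddle source).
import numpy as np

decay = 0.999
ema = np.zeros(3)                        # running (biased) average
params = [np.random.rand(3) for _ in range(5)]

for t, p in enumerate(params, start=1):  # t counts completed EMA updates
    ema = decay * ema + (1 - decay) * p
    debiased = ema / (1.0 - decay ** t)  # safe only because t >= 1 here
```

Before the first update (t = 0) there is nothing to correct, so the fixed code simply skips the division.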

File tree: 2 files changed (+102, -6 lines)


python/paddle/fluid/optimizer.py

Lines changed: 17 additions & 6 deletions
@@ -3017,7 +3017,7 @@ class ExponentialMovingAverage(object):
             optimizer = fluid.optimizer.Adam(learning_rate=0.001)
             optimizer.minimize(cost)
 
-            global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
+            global_steps = fluid.layers.autoincreased_step_counter()
             ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
             ema.update()
@@ -3055,6 +3055,7 @@ def __init__(self, decay=0.999, thres_steps=None, name=None):
         self._name = name if name is not None else ''
         self._decay_var = self._get_ema_decay()
 
+        self._step_counter_name = "@EMA_STEP_COUNTER@"
         self._params_tmps = []
         for param in default_main_program().global_block().all_parameters():
             if param.do_model_average != False:
@@ -3075,14 +3076,16 @@ def __init__(self, decay=0.999, thres_steps=None, name=None):
         self.apply_program = Program()
         block = self.apply_program.global_block()
         with program_guard(main_program=self.apply_program):
-            decay_pow = self._get_decay_pow(block)
+            decay_pow, global_step = self._get_decay_pow(block)
             for param, tmp in self._params_tmps:
                 param = block._clone_variable(param)
                 tmp = block._clone_variable(tmp)
                 ema = block._clone_variable(self._ema_vars[param.name])
                 layers.assign(input=param, output=tmp)
                 # bias correction
-                ema = ema / (1.0 - decay_pow)
+                with layers.control_flow.Switch() as switch:
+                    with switch.case(global_step > 0):
+                        layers.assign(output=ema, input=ema / (1.0 - decay_pow))
                 layers.assign(input=ema, output=param)
 
         self.restore_program = Program()
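
A rough plain-Python paraphrase of what the rewritten apply program now does for each parameter (illustrative only; the real code emits fluid ops inside apply_program rather than running eagerly):

```python
# Paraphrase of the per-parameter apply logic above (illustrative only).
import numpy as np

def apply_one(param, ema, decay_pow, global_step):
    tmp = param.copy()                 # stash current weights so restore() can undo this
    if global_step > 0:                # the Switch(global_step > 0) case in the diff
        ema = ema / (1.0 - decay_pow)  # bias correction, skipped before any update
    param[...] = ema                   # load the (corrected) EMA into the parameter
    return tmp
```

With the corrected exponent decay**global_step (next hunk), applying at step 0 would otherwise divide by 1 - decay**0 = 0, hence the guard.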
@@ -3115,10 +3118,16 @@ def _get_ema_decay(self):
         return decay_var
 
     def _get_decay_pow(self, block):
-        global_steps = layers.learning_rate_scheduler._decay_step_counter()
+        global_step = layers.create_global_var(
+            name=self._step_counter_name,
+            shape=[1],
+            value=0,
+            dtype='int64',
+            persistable=True)
+        global_step = layers.cast(global_step, "float32")
         decay_var = block._clone_variable(self._decay_var)
-        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
-        return decay_pow_acc
+        decay_pow_acc = layers.elementwise_pow(decay_var, global_step)
+        return decay_pow_acc, global_step
 
     def _create_ema_vars(self, param):
         param_ema = layers.create_global_var(
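
As I read this hunk together with update() below, the step count now lives in a dedicated persistable variable: update() increments it once per run of the train program, while _get_decay_pow only reads it from the apply program, so entering apply() repeatedly no longer advances the count the way the shared learning-rate-scheduler counter did. A small sketch of that contract with the fluid 1.x API (program names are illustrative):

```python
# Sketch of the counter contract (fluid 1.x API; program names are illustrative).
import paddle.fluid as fluid

train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    # Incremented by 1 every time the train program runs (what update() relies on).
    step = fluid.layers.autoincreased_step_counter(
        counter_name="@EMA_STEP_COUNTER@", begin=1, step=1)

apply_prog = fluid.Program()
with fluid.program_guard(apply_prog):
    # The same variable fetched by name: no increment op is added here, so
    # running apply_prog any number of times leaves the count unchanged.
    step_read = fluid.layers.create_global_var(
        name="@EMA_STEP_COUNTER@", shape=[1], value=0,
        dtype='int64', persistable=True)
```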
@@ -3135,6 +3144,8 @@ def update(self):
         Update Exponential Moving Average. Should only call this method in
         train program.
         """
+        global_step = layers.autoincreased_step_counter(
+            counter_name=self._step_counter_name)
         param_master_emas = []
         for param, tmp in self._params_tmps:
             with param.block.program._optimized_guard(
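
Putting the pieces together, user-facing usage is unchanged; below is a minimal end-to-end sketch assembled from the updated docstring and the new unit test (fluid 1.6-era API; the network, shapes, and feed data are illustrative):

```python
# End-to-end EMA usage sketch (assembled from the docstring and the new test).
import numpy as np
import paddle.fluid as fluid

train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
    hidden = fluid.layers.fc(input=data, size=10)
    cost = fluid.layers.mean(hidden)

    test_prog = fluid.default_main_program().clone(for_test=True)

    fluid.optimizer.Adam(learning_rate=0.001).minimize(cost)

    ema = fluid.optimizer.ExponentialMovingAverage(0.999)
    ema.update()      # adds the EMA ops and the @EMA_STEP_COUNTER@ increment

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)

for _ in range(10):   # each run of the train program advances the EMA one step
    exe.run(train_prog, feed={'x': np.random.random((8, 5)).astype('float32')})

with ema.apply(exe):  # swap the bias-corrected averages into the parameters
    exe.run(test_prog, feed={'x': np.random.random((8, 5)).astype('float32')})
# the original weights are restored automatically when the `with` block exits
```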
Lines changed: 85 additions & 0 deletions (new unit test file)
@@ -0,0 +1,85 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+
+
+class TestExponentialMovingAverage(unittest.TestCase):
+    def setUp(self):
+        self._places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            self._places.append(fluid.CUDAPlace(0))
+        self._ema_decay = 0.999
+        self._param_name = "fc.weight"
+
+        self._train_program = fluid.Program()
+        self._startup_prog = fluid.Program()
+        with fluid.program_guard(self._train_program, self._startup_prog):
+            with fluid.unique_name.guard():
+                data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
+                hidden = fluid.layers.fc(input=data,
+                                         size=10,
+                                         param_attr=self._param_name)
+                cost = fluid.layers.mean(hidden)
+
+                self._test_program = fluid.default_main_program().clone(
+                    for_test=True)
+
+                optimizer = fluid.optimizer.Adam(learning_rate=0.001)
+                optimizer.minimize(cost)
+
+                self._ema = fluid.optimizer.ExponentialMovingAverage(
+                    self._ema_decay)
+                self._ema.update()
+
+    def train(self, place):
+        exe = fluid.Executor(place)
+        exe.run(self._startup_prog)
+
+        params = []
+        for pass_id in range(2):
+            for batch_id in range(3):
+                data = np.random.random(size=(10, 5)).astype('float32')
+                tmp_param = np.array(fluid.global_scope().find_var(
+                    self._param_name).get_tensor())
+                exe.run(program=self._train_program, feed={'x': data})
+                tmp_param = np.array(fluid.global_scope().find_var(
+                    self._param_name).get_tensor())
+                params.append(tmp_param)
+
+        with self._ema.apply(exe):
+            final_ema = np.array(fluid.global_scope().find_var(self._param_name)
+                                 .get_tensor())
+            data = np.random.random(size=(10, 5)).astype('float32')
+            exe.run(program=self._test_program, feed={'x': data})
+        return params, final_ema
+
+    def test_check_ema(self):
+        for place in self._places:
+            params, final_ema = self.train(place)
+            manu_ema = np.zeros_like(final_ema)
+            if len(params) > 0:
+                for param in params:
+                    manu_ema = self._ema_decay * manu_ema + (1 - self._ema_decay
+                                                             ) * param
+                manu_ema = manu_ema / (1.0 - self._ema_decay**len(params))
+            self.assertTrue(np.allclose(manu_ema, final_ema))
+
+
+if __name__ == '__main__':
+    unittest.main()
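
For reference, the manual check in test_check_ema is the standard bias-corrected EMA written out in closed form, with $\beta$ the decay (0.999 here), $\theta_k$ the parameter value recorded after the k-th training step, and $N$ the number of recorded steps:

$$\mathrm{ema}_N = (1-\beta)\sum_{k=1}^{N} \beta^{\,N-k}\,\theta_k, \qquad \widehat{\mathrm{ema}}_N = \frac{\mathrm{ema}_N}{1-\beta^{N}}$$

which is what the accumulation loop plus the final division by 1 - self._ema_decay**len(params) computes, and what apply() is expected to load into fc.weight.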
