Skip to content

Commit 6125a58

Browse files
author
Yibing Liu
authored
Fix ema's example & fp16 update (#18273) (#18275)
test=release/1.5
1 parent 575bc57 commit 6125a58

File tree

1 file changed

+63
-33
lines changed

1 file changed

+63
-33
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,36 +2458,50 @@ class ExponentialMovingAverage(object):
24582458
Examples:
24592459
24602460
.. code-block:: python
2461-
2462-
import paddle.fluid as fluid
2463-
2464-
data = fluid.layers.data(name='x', shape=[5], dtype='float32')
2465-
hidden = fluid.layers.fc(input=data, size=10)
2466-
cost = fluid.layers.mean(hidden)
2467-
2468-
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
2469-
optimizer.minimize(cost)
2470-
2471-
global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
2472-
ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
2473-
ema.update()
2474-
2475-
# pseudo code
2476-
for pass_id in range(args.pass_num):
2477-
for data in train_reader():
2478-
exe.run(fluid.default_main_program()...)
2479-
2480-
# usage 1
2481-
with ema.apply(exe):
2482-
for data in test_reader():
2483-
exe.run(inference_program...)
2484-
2485-
# usage 2
2486-
with ema.apply(exe, need_restore=False):
2487-
for data in test_reader():
2488-
exe.run(inference_program...)
2489-
...
2490-
ema.restore(exe)
2461+
2462+
import numpy
2463+
import paddle
2464+
import paddle.fluid as fluid
2465+
2466+
data = fluid.layers.data(name='x', shape=[5], dtype='float32')
2467+
hidden = fluid.layers.fc(input=data, size=10)
2468+
cost = fluid.layers.mean(hidden)
2469+
2470+
test_program = fluid.default_main_program().clone(for_test=True)
2471+
2472+
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
2473+
optimizer.minimize(cost)
2474+
2475+
global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
2476+
ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
2477+
ema.update()
2478+
2479+
place = fluid.CPUPlace()
2480+
exe = fluid.Executor(place)
2481+
exe.run(fluid.default_startup_program())
2482+
2483+
for pass_id in range(3):
2484+
for batch_id in range(6):
2485+
data = numpy.random.random(size=(10, 5)).astype('float32')
2486+
exe.run(program=fluid.default_main_program(),
2487+
feed={'x': data},
2488+
fetch_list=[cost.name])
2489+
2490+
# usage 1
2491+
with ema.apply(exe):
2492+
data = numpy.random.random(size=(10, 5)).astype('float32')
2493+
exe.run(program=test_program,
2494+
feed={'x': data},
2495+
fetch_list=[hidden.name])
2496+
2497+
2498+
# usage 2
2499+
with ema.apply(exe, need_restore=False):
2500+
data = numpy.random.random(size=(10, 5)).astype('float32')
2501+
exe.run(program=test_program,
2502+
feed={'x': data},
2503+
fetch_list=[hidden.name])
2504+
ema.restore(exe)
24912505
"""
24922506

24932507
def __init__(self, decay=0.999, thres_steps=None, name=None):
@@ -2576,13 +2590,29 @@ def update(self):
25762590
Update Exponential Moving Average. Should only call this method in
25772591
train program.
25782592
"""
2593+
param_master_emas = []
25792594
for param, tmp in self._params_tmps:
25802595
with param.block.program._optimized_guard(
25812596
[param, tmp]), name_scope('moving_average'):
25822597
param_ema = self._ema_vars[param.name]
2583-
ema_t = param_ema * self._decay_var + param * (1 -
2584-
self._decay_var)
2585-
layers.assign(input=ema_t, output=param_ema)
2598+
if self._ema_vars.has_key(param.name + '.master'):
2599+
master_ema = self._ema_vars[param.name + '.master']
2600+
param_master_emas.append([param_ema, master_ema])
2601+
else:
2602+
ema_t = param_ema * self._decay_var + param * (
2603+
1 - self._decay_var)
2604+
layers.assign(input=ema_t, output=param_ema)
2605+
2606+
# for fp16 params
2607+
for param_ema, master_ema in param_master_emas:
2608+
default_main_program().global_block().append_op(
2609+
type="cast",
2610+
inputs={"X": master_ema},
2611+
outputs={"Out": param_ema},
2612+
attrs={
2613+
"in_dtype": master_ema.dtype,
2614+
"out_dtype": param_ema.dtype
2615+
})
25862616

25872617
@signature_safe_contextmanager
25882618
def apply(self, executor, need_restore=True):

0 commit comments

Comments
 (0)