
Commit 3317cf0

[cherry pick] Add pure fp16 amp_init for fleet API. (#30592)
* add fleet amp_init()
* add unittest for fleet_amp_init
1 parent 619869b commit 3317cf0

3 files changed, +150 -2 lines changed


python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 64 additions & 0 deletions
@@ -958,6 +958,70 @@ def forward(self, x):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()

+    def amp_init(self,
+                 place,
+                 scope=None,
+                 test_program=None,
+                 use_fp16_test=False):
+        """
+        Initialize AMP training, such as casting fp32 parameters to fp16 type.
+
+        Args:
+            place(CUDAPlace): place is used to initialize
+                fp16 parameters with fp32 values.
+            scope(Scope): The scope is used to find fp32 parameters.
+            test_program(Program): The program is used for testing.
+            use_fp16_test(bool): Whether to use fp16 testing.
+
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+                import paddle
+                import paddle.nn.functional as F
+                paddle.enable_static()
+
+                def run_example_code():
+                    place = paddle.CUDAPlace(0)
+                    exe = paddle.static.Executor(place)
+                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
+                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
+                    # 1) Use fp16_guard to control the range of fp16 kernels used.
+                    with paddle.static.amp.fp16_guard():
+                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
+                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
+                        hidden = paddle.static.nn.fc(pool, size=10)
+                        loss = paddle.mean(hidden)
+                    # 2) Create the optimizer and set `multi_precision` to True.
+                    # Setting `multi_precision` to True helps avoid poor accuracy
+                    # or slow convergence.
+                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
+                    # 3) The ops in `custom_black_list` will be kept in float32 computation.
+                    amp_list = paddle.static.amp.CustomOpLists(
+                        custom_black_list=['pool2d'])
+                    # 4) The entry of Paddle AMP.
+                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
+                    optimizer = paddle.static.amp.decorate(
+                        optimizer,
+                        amp_list,
+                        init_loss_scaling=128.0,
+                        use_dynamic_loss_scaling=True,
+                        use_pure_fp16=True)
+                    # If you don't use the default_startup_program(), you should pass
+                    # your defined `startup_program` into `minimize`.
+                    optimizer.minimize(loss)
+                    exe.run(paddle.static.default_startup_program())
+                    # 5) Use `amp_init` after FP32 parameter initialization (such as `exe.run(startup_program)`).
+                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
+                    optimizer.amp_init(place, scope=paddle.static.global_scope())
+
+                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
+                    run_example_code()
+        """
+        # imitate target optimizer retrieval
+        return self.user_defined_optimizer.amp_init(
+            place, scope=scope, test_program=test_program, use_fp16_test=use_fp16_test)

     def _final_strategy(self):
         if "valid_strategy" not in self._context:
             print(
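
Beyond the docstring above, a minimal end-to-end sketch of the new fleet call path may help; it is not part of the commit and assumes a CUDA build, a single-card collective launch via `fleet.init(is_collective=True)`, and an illustrative toy network:

    # Sketch only: how a fleet user would reach the new amp_init entry point.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()

    # Tiny illustrative fp32 network; any static graph works the same way.
    x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    hidden = paddle.static.nn.fc(x=x, size=64, activation='tanh')
    loss = paddle.mean(hidden)

    # Pure fp16 training; multi_precision keeps fp32 master weights.
    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
    optimizer = paddle.static.amp.decorate(optimizer, use_pure_fp16=True)

    fleet.init(is_collective=True)
    optimizer = fleet.distributed_optimizer(optimizer)
    optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())
    # The entry point added by this commit: cast fp32 parameters to fp16
    # once the startup program has initialized them.
    optimizer.amp_init(place, scope=paddle.static.global_scope())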

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 6 additions & 2 deletions
@@ -95,6 +95,9 @@ def _update_list(self):
     'sigmoid_cross_entropy_with_logits',
     'cross_entropy',
     'cross_entropy2',
+    # fp16 is slower than fp32, though fp16 is supported.
+    'lookup_table',
+    'lookup_table_v2',
 }

 # This set contains two types of ops. All ops supported fp16 calculation. One
@@ -115,8 +118,6 @@ def _update_list(self):
     'layer_norm',
     'tanh',
     'sigmoid',
-    'lookup_table',
-    'lookup_table_v2',
     'top_k',
     'pool2d',
     'pool3d',
@@ -284,6 +285,9 @@ def _update_list(self):
     'generate_proposals',
     'generate_proposal_labels',
     'generate_mask_labels',
+    # fp16 is slower than fp32, though fp16 is supported.
+    'lookup_table',
+    'lookup_table_v2',
 }

 CustomOpLists = AutoMixedPrecisionLists
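
The practical effect of this list change is that embedding lookups stay in fp32 even under pure fp16 training, since `lookup_table` and `lookup_table_v2` now land in the unsupported-fp16 set built by `AutoMixedPrecisionLists`. A quick check, assuming the `white_list` / `unsupported_list` attributes that class exposes, might look like this:

    # Sketch only: inspect the default AMP op lists after this change.
    import paddle

    paddle.enable_static()

    # CustomOpLists is the alias defined at the bottom of fp16_lists.py.
    amp_list = paddle.static.amp.CustomOpLists()

    # After this commit, lookup_table ops should no longer be fp16 candidates.
    print('lookup_table' in amp_list.unsupported_list)     # expected: True
    print('lookup_table_v2' in amp_list.unsupported_list)  # expected: True
    print('lookup_table' in amp_list.white_list)           # expected: False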
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.distributed.fleet.base.role_maker as role_maker
+import paddle.distributed.fleet as fleet
+import paddle.fluid as fluid
+import unittest
+import paddle.nn.functional as F
+import numpy as np
+
+paddle.enable_static()
+
+
+def gen_data():
+    return {
+        "x": np.random.random(size=(128, 32)).astype('float32'),
+        "y": np.random.randint(
+            2, size=(128, 1)).astype('int64')
+    }
+
+
+def mlp(input_x, input_y, hid_dim=128, label_dim=2):
+    fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
+    fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
+    prediction = paddle.static.nn.fc(x=[fc_2],
+                                     size=label_dim,
+                                     activation='softmax')
+    cost = F.cross_entropy(input=prediction, label=input_y)
+    avg_cost = paddle.mean(x=cost)
+    return avg_cost


+class TestFleetAMPInit(unittest.TestCase):
+    def test_fleet_amp_init(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+        input_x = paddle.static.data(
+            name="x", shape=[None, 32], dtype='float32')
+        input_y = paddle.static.data(name="y", shape=[None, 1], dtype='int64')
+
+        cost = mlp(input_x, input_y)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=fluid.regularizer.L2Decay(1e-4),
+            multi_precision=True)
+
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+
+        optimizer = paddle.static.amp.decorate(optimizer)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        optimizer.minimize(cost)
+        place = paddle.CUDAPlace(0)
+
+        exe = paddle.static.Executor(place)
+        exe.run(paddle.static.default_startup_program())
+        optimizer.amp_init(place, use_fp16_test=True)
+
+        step = 1
+        for i in range(step):
+            cost_val = exe.run(program=paddle.static.default_main_program(),
+                               feed=gen_data(),
+                               fetch_list=[cost.name])
+
+
+if __name__ == '__main__':
+    unittest.main()
