Commit 03babe1

Fleet distributed strategy support pure fp16 (#30754) (#31238)
1 parent 188bcbb

13 files changed: +178 -20 lines

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ message AMPConfig {
   repeated string custom_white_list = 7;
   repeated string custom_black_list = 8;
   repeated string custom_black_varnames = 9;
+  optional bool use_pure_fp16 = 10 [ default = false ];
+  optional bool use_fp16_guard = 11 [ default = true ];
 }
 
 message LocalSGDConfig {
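These two protobuf fields surface in Python through `DistributedStrategy.amp_configs` (see the docstring change below). A minimal sketch of how a user would set them, based on that docstring:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.amp_configs = {
        "init_loss_scaling": 32768,
        "use_pure_fp16": True,    # new AMPConfig field, default false
        "use_fp16_guard": True,   # new AMPConfig field, default true; only honored with pure fp16
    }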

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 24 additions & 1 deletion
@@ -49,6 +49,9 @@ def assign_configs_value(msg, config):
     for key in config:
         for f in fields:
             if key == f.name:
+                # LABEL_OPTIONAL = 1
+                # LABEL_REPEATED = 3
+                # LABEL_REQUIRED = 2
                 if f.label == 3:
                     getattr(msg, f.name).extend(config[f.name])
                 elif f.label == 1 or f.label == 2:
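The label values referenced in the new comment come from protobuf's `FieldDescriptor`; a quick check, assuming the standard `protobuf` package is installed:

    from google.protobuf.descriptor import FieldDescriptor

    # The constants the comment above refers to:
    print(FieldDescriptor.LABEL_OPTIONAL)  # 1
    print(FieldDescriptor.LABEL_REQUIRED)  # 2
    print(FieldDescriptor.LABEL_REPEATED)  # 3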
@@ -366,7 +369,14 @@ def amp_configs(self):
 
             custom_black_list(list[str]): Users' custom black list which forbidden execution fp16.
 
-        Examples:
+            custom_black_varnames(list[str]): Users' custom black variables' names.
+
+            use_pure_fp16(bool): Whether to use pure fp16 training. Default False.
+
+            use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
+                   Default True. Only takes effect when `use_pure_fp16` is turned on.
+
+        Examples 1:
 
             .. code-block:: python
 
@@ -376,6 +386,19 @@ def amp_configs(self):
                 strategy.amp_configs = {
                     "init_loss_scaling": 32768,
                     "custom_white_list": ['conv2d']}
+
+        Examples 2:
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+                strategy = fleet.DistributedStrategy()
+                strategy.amp = True
+                # pure fp16
+                strategy.amp_configs = {
+                    "init_loss_scaling": 32768,
+                    "use_pure_fp16": True
+                }
         """
         return get_msg_dict(self.strategy.amp_configs)

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 16 additions & 2 deletions
@@ -196,6 +196,7 @@ def init(self, role_maker=None, is_collective=False, strategy=None):
         else:
             if isinstance(role_maker, RoleMakerBase):
                 self._role_maker = role_maker
+                self._is_collective = role_maker._is_collective
             else:
                 raise ValueError(
                     "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
@@ -1022,9 +1023,22 @@ def run_example_code():
             if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                 run_example_code()
         """
+
         # imitate target optimizer retrieval
-        return self.user_defined_optimizer.amp_init(place, scope, test_program,
-                                                    use_fp16_test)
+        amp_optimizer = None
+        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
+            if hasattr(optimizer, 'amp_init'):
+                amp_optimizer = optimizer
+                break
+
+        if amp_optimizer is None:
+            if hasattr(self.user_defined_optimizer, 'amp_init'):
+                amp_optimizer = self.user_defined_optimizer
+
+        assert amp_optimizer is not None, \
+            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
+
+        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
 
     def _final_strategy(self):
         if "valid_strategy" not in self._context:

python/paddle/distributed/fleet/base/strategy_compiler.py

Lines changed: 3 additions & 0 deletions
@@ -129,6 +129,9 @@ def __init__(self):
         self._meta_optimizer_candidates = []
         self._graph_optimizer_candidates = []
 
+    def _get_applied_meta_optimizer(self):
+        return self._meta_optimizers
+
     def _get_applied_meta_list(self):
         return [type(opt).__name__ for opt in self._meta_optimizers]

python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py

Lines changed: 10 additions & 1 deletion
@@ -50,7 +50,8 @@ def _init_wrapped_opt(self):
             self.inner_opt, amp_lists, config['init_loss_scaling'],
             config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
             config['incr_ratio'], config['decr_ratio'],
-            config['use_dynamic_loss_scaling'])
+            config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
+            config['use_fp16_guard'])
 
         # if worker_num > 1, all cards will communication with each other,
         # add is_distributed to optimize amp, overlap communication and
@@ -112,3 +113,11 @@ def minimize_impl(self,
             self.wrapped_opt.minimize(loss, startup_program,
                                       parameter_list, no_grad_set)
         return optimize_ops, params_grads
+
+    def amp_init(self,
+                 place,
+                 scope=None,
+                 test_program=None,
+                 use_fp16_test=False):
+        return self.wrapped_opt.amp_init(place, scope, test_program,
+                                         use_fp16_test)
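The two new config entries are forwarded straight to the underlying AMP decorator. Roughly equivalent standalone usage, assuming `paddle.static.amp.decorate` in this Paddle version exposes the same flags as keyword arguments (a sketch, not the exact call above):

    import paddle

    paddle.enable_static()
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.001, momentum=0.9, multi_precision=True)
    # Wrap the optimizer for mixed precision; the last two flags are the ones
    # newly forwarded by AMPOptimizer above.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        init_loss_scaling=32768,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True,
        use_fp16_guard=False)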

python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py

Lines changed: 3 additions & 1 deletion
@@ -165,7 +165,9 @@ def _try_to_compile(self, startup_program, main_program, loss):
             main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks
 
         # TODO(guru4elephant): should be an independent optimizer
-        self._setup_nccl_op(startup_program, main_program, local_build_strategy)
+        if worker_num > 1:
+            self._setup_nccl_op(startup_program, main_program,
+                                local_build_strategy)
 
         local_build_strategy.num_trainers = self.role_maker._worker_num()
         local_build_strategy.trainer_id = self.role_maker._worker_index()

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
@@ -506,6 +507,7 @@ if(WITH_DISTRIBUTE)
     py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
+    py_test_modules(test_fleet_amp_init MODULES test_fleet_amp_init ENVS ${dist_ENVS})
     py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
     py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})

python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py

Lines changed: 15 additions & 0 deletions
@@ -88,6 +88,21 @@ def set_strategy(self, strategy, name):
                 "custom_white_list": ['softmax'],
                 "custom_black_list": ['tanh'],
             }
+        elif name == 'pure_fp16':
+            strategy.amp = True
+            strategy.amp_configs = {
+                "init_loss_scaling": 32768,
+                "decr_every_n_nan_or_inf": 2,
+                "incr_every_n_steps": 1000,
+                "incr_ratio": 2.0,
+                "use_dynamic_loss_scaling": True,
+                "decr_ratio": 0.5,
+                "custom_white_list": ['softmax'],
+                "custom_black_list": ['tanh'],
+                "use_pure_fp16": True,
+                "use_fp16_guard": False,
+            }
+
         elif name == 'dgc':
             strategy.dgc = True
             strategy.dgc_configs = {

python/paddle/fluid/tests/unittests/test_fleet_amp_init.py

Lines changed: 68 additions & 14 deletions
@@ -46,34 +46,88 @@ class TestFleetAMPInit(unittest.TestCase):
     def test_fleet_amp_init(self):
         if not fluid.core.is_compiled_with_cuda():
             return
-        input_x = paddle.static.data(
-            name="x", shape=[None, 32], dtype='float32')
-        input_y = paddle.static.data(name="y", shape=[None, 1], dtype='int64')
 
-        cost = mlp(input_x, input_y)
-        optimizer = paddle.optimizer.Momentum(
-            learning_rate=0.001,
-            momentum=0.9,
-            weight_decay=fluid.regularizer.L2Decay(1e-4),
-            multi_precision=True)
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
 
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
         fleet.init(role)
 
-        optimizer = paddle.static.amp.decorate(optimizer)
-        optimizer = fleet.distributed_optimizer(optimizer)
-        optimizer.minimize(cost)
+        with paddle.static.program_guard(main_program, startup_program):
+            input_x = paddle.static.data(
+                name="x", shape=[None, 32], dtype='float32')
+            input_y = paddle.static.data(
+                name="y", shape=[None, 1], dtype='int64')
+
+            cost = mlp(input_x, input_y)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.001,
+                momentum=0.9,
+                weight_decay=fluid.regularizer.L2Decay(1e-4),
+                multi_precision=True)
+
+            optimizer = paddle.static.amp.decorate(optimizer)
+            optimizer = fleet.distributed_optimizer(optimizer)
+            optimizer.minimize(cost)
+
         place = paddle.CUDAPlace(0)
 
         exe = paddle.static.Executor(place)
-        exe.run(paddle.static.default_startup_program())
+        exe.run(startup_program)
         optimizer.amp_init(place)
 
         step = 1
         for i in range(step):
-            cost_val = exe.run(program=paddle.static.default_main_program(),
+            cost_val = exe.run(program=main_program,
+                               feed=gen_data(),
+                               fetch_list=[cost.name])
+
+    def test_fleet_amp_meta_optimizer_init(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+
+        with paddle.static.program_guard(main_program, startup_program):
+            input_x = paddle.static.data(
+                name="x", shape=[None, 32], dtype='float32')
+            input_y = paddle.static.data(
+                name="y", shape=[None, 1], dtype='int64')
+
+            cost = mlp(input_x, input_y)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.001,
+                momentum=0.9,
+                weight_decay=fluid.regularizer.L2Decay(1e-4),
+                multi_precision=True)
+
+            strategy = paddle.distributed.fleet.DistributedStrategy()
+            strategy.amp = True
+            strategy.amp_configs = {'use_pure_fp16': True}
+            strategy.gradient_merge = True
+            strategy.gradient_merge_configs = {"k_steps": 2}
+
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+            optimizer.minimize(cost)
+
+        print(fleet._get_applied_meta_list())
+
+        place = paddle.CUDAPlace(0)
+
+        exe = paddle.static.Executor(place)
+        exe.run(startup_program)
+        optimizer.amp_init(place)
+
+        step = 3
+        for i in range(step):
+            cost_val = exe.run(program=main_program,
                                feed=gen_data(),
                                fetch_list=[cost.name])
+            print(cost_val)
 
 
 if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py

Lines changed: 15 additions & 0 deletions
@@ -93,6 +93,21 @@ def test_amp_optimizer(self):
         self.assertIn('cast', ops)
         self.assertIn('check_finite_and_unscale', ops)
 
+    def test_pure_fp16_optimizer(self):
+        """ test pure fp16 """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'pure_fp16')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        params = train_prog.all_parameters()
+        for param in params:
+            self.assertEqual(param.dtype, fluid.core.VarDesc.VarType.FP16)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
     def test_amp_distributed_optimizer(self):
         """ test amp when distributed """
         train_prog, startup_prog = fluid.Program(), fluid.Program()
