
Commit bb5963d

Authored by lilong12, FeixLiu, and youth123

[CP] add a strategy to run program with fleet (#33511)

* Add raw program meta optimizer (#32597)
* add raw program, test=develop
* add precision unittest for executor all reduce (#33339)
* fix dp (#33297)

Co-authored-by: Yuang Liu <[email protected]>
Co-authored-by: 李季 <[email protected]>

1 parent 7be50f9, commit bb5963d
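For context, a minimal usage sketch of the flag this commit introduces. It is not part of the commit itself; it combines the docstring example added below with the standard fleet collective workflow, and the model/optimizer pieces are left as commented placeholders.

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()

    # Enable the new strategy flag so RawProgramOptimizer is applied and the
    # program runs with a plain Executor instead of ParallelExecutor.
    strategy = fleet.DistributedStrategy()
    strategy.without_graph_optimization = True

    fleet.init(is_collective=True, strategy=strategy)

    # Placeholder: build a static-graph network and loss here, then:
    # optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    # optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss)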

File tree: 10 files changed, +592 −5 lines

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ message DistributedStrategy {
   optional float last_comm_group_size_MB = 27 [ default = 1 ];
   optional bool find_unused_parameters = 28 [ default = false ];
   optional bool tensor_parallel = 29 [ default = false ];
+  optional bool without_graph_optimization = 30 [ default = false ];

   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 26 additions & 0 deletions

@@ -827,6 +827,32 @@ def sharding_configs(self, configs):
                           "sharding_configs")
         assign_configs_value(self.strategy.sharding_configs, configs)

+    @property
+    def without_graph_optimization(self):
+        """
+        Run the program using Executor instead of ParallelExecutor.
+
+        Examples:
+
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.without_graph_optimization = True
+
+        """
+        return self.strategy.without_graph_optimization
+
+    @without_graph_optimization.setter
+    @is_strict_auto
+    def without_graph_optimization(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.without_graph_optimization = flag
+        else:
+            print(
+                "WARNING: without_graph_optimization should be of bool type"
+            )
+
     @property
     def pipeline(self):
         """
python/paddle/distributed/fleet/meta_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -28,3 +28,4 @@
 from .dygraph_optimizer import HybridParallelOptimizer
 from .dygraph_optimizer import HybridParallelGradScaler
 from .tensor_parallel_optimizer import TensorParallelOptimizer
+from .raw_program_optimizer import RawProgramOptimizer
python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py

Lines changed: 197 additions & 0 deletions (new file; shown without diff markers)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

from __future__ import print_function
from __future__ import division
import os

import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_loss_grad_op, is_backward_op, is_optimizer_op


class RawProgramOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(RawProgramOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
        self.global_ring_id = 0

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        super(RawProgramOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)
        self.without_graph_optimization = user_defined_strategy.without_graph_optimization

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False

        if self.without_graph_optimization == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.without_graph_optimization = False

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.without_graph_optimization = True

    def _broadcast_params(self, ring_id):
        block = self.startup_program.global_block()
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward
                })

        if not param: return  # no parameter on this device
        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Forward})

    def _get_process_group_info(self):
        # global ring info
        self.global_endpoints = self.endpoints
        self.global_rank = self.rank
        self.global_nranks = self.nranks

    def _init_process_group(self):
        self._get_process_group_info()
        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
        # Create global ring for all gpus (ring_id = 0)
        collective_helper._init_communicator(
            self.startup_program, self.current_endpoint, self.global_endpoints,
            self.global_rank, self.global_ring_id, True, self.global_ring_id,
            True)
        self._broadcast_params(self.global_ring_id)

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        self.endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        if startup_program is None:
            startup_program = fluid.default_startup_program()
        self.startup_program = startup_program

        block = loss.block
        program = block.program
        self.main_program = program

        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)
        if self.nranks == 1:
            return optimize_ops, params_grads
        self._init_process_group()

        self.main_program = program
        if self.nranks > 1:
            self._transpile_main_program(loss)
        return optimize_ops, params_grads

    def _transpile_main_program(self, loss):
        self._insert_loss_grad_ops(loss)
        self._insert_allreduce_ops()

    def _insert_loss_grad_ops(self, loss):
        """
        In order to keep the learning rate consistent across different numbers
        of training workers, we scale the loss grad by the number of workers.
        """
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self.nranks,
                        OP_ROLE_KEY: OpRole.Backward
                    })

    def _insert_allreduce_ops(self):
        block = self.main_program.global_block()
        ring_id = self.global_ring_id
        grad = None
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                offset = 1
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue

                    block._insert_op(
                        idx + offset,
                        type='c_sync_calc_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={OP_ROLE_KEY: OpRole.Backward, })
                    offset += 1
                    block._insert_op(
                        idx + offset,
                        type='c_allreduce_sum',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Backward
                        })

        if grad is None:
            return

        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={'ring_id': ring_id,
                           OP_ROLE_KEY: OpRole.Backward})
                break
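The transpile pass above scales the loss gradient by 1/nranks (the scale op) and then sums each parameter gradient across workers (c_allreduce_sum), which together are equivalent to averaging gradients over the workers, keeping the effective learning rate independent of the worker count. A small NumPy check of that identity, illustrative only and not part of the commit:

    import numpy as np

    nranks = 4
    # Per-worker gradients of one parameter (illustrative values).
    grads = [np.random.rand(3) for _ in range(nranks)]

    # What the transpiled program computes: each worker scales its loss gradient
    # by 1.0 / nranks, which propagates to every parameter gradient, and the
    # scaled gradients are then summed across workers by c_allreduce_sum.
    allreduced = sum(g / nranks for g in grads)

    # That is exactly the average of the per-worker gradients.
    expected = np.mean(grads, axis=0)
    assert np.allclose(allreduced, expected)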

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 8 additions & 4 deletions

Four of the changed lines differ only in trailing whitespace; they appear below as otherwise-identical -/+ pairs.

@@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
 list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)

@@ -54,6 +55,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)

@@ -100,6 +102,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
 LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api)
 LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
 LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
+LIST(REMOVE_ITEM TEST_OPS test_raw_program_optimizer)
 endif()

 if(WIN32)

@@ -571,7 +574,7 @@ endif()
 py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf)
 # Coverage pipeline use cuda 10.1 now, profiler will random hang in cuda 10.1,
 # see https://github.com/PaddlePaddle/Paddle/issues/29082 for details.
-# We guess there are some bugs in cuda 10.1 or 10.2,
+# We guess there are some bugs in cuda 10.1 or 10.2,
 # since this unittest is stable in cuda 11 (py3 pipeline) now.
 if(NOT WITH_COVERAGE)
 py_test_modules(test_parallel_executor_profiler MODULES test_parallel_executor_profiler)

@@ -596,8 +599,8 @@ py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_c
 py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)

 # NOTE: These unittests will appear NaN steadily in windows CI. After analysis,
-# it is found that windows CI will run all the training unittests with the ON_INFER option turned on,
-# which will not appear in other CIs. The calculation behavior of some ops in inference mode is
+# it is found that windows CI will run all the training unittests with the ON_INFER option turned on,
+# which will not appear in other CIs. The calculation behavior of some ops in inference mode is
 # inconsistent with that in non-inference mode.
 if(NOT ON_INFER)
 py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_executor_seresnext_base_cpu)

@@ -640,7 +643,7 @@ if (WITH_XPU)
 add_subdirectory(xpu)
 endif()

-# dist xpu tests:
+# dist xpu tests:
 if (WITH_XPU_BKCL)
 py_test(test_collective_reduce_api_xpu SRCS "test_collective_reduce_api.py")
 py_test(test_collective_allreduce_api_xpu SRCS "test_collective_allreduce_api.py")

@@ -708,6 +711,7 @@ if (WITH_DISTRIBUTE)
 set_tests_properties(test_dist_fleet_ctr2 PROPERTIES TIMEOUT 200)
 set_tests_properties(test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200)
 set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
+set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
 endif()

 if (WITH_DISTRIBUTE AND NOT APPLE)