
Commit 0b8067c

fix dist train reduce mode (#13068)

* fix dist train reduce mode
* fix previous fix
1 parent 823c4f8 commit 0b8067c

File tree

4 files changed: +79 additions, -21 deletions

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -744,7 +744,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
               .emplace(varname, op_dev_id);
         }
       } else {
-        PADDLE_ENFORCE(
+        PADDLE_THROW(
             "the distribute training related op should be in [split_byref, "
             "concat].");
       }
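
Note: PADDLE_ENFORCE expects a boolean condition as its first argument, so the old call, which passed only the message string, would presumably never fire; PADDLE_THROW raises unconditionally, which is what this catch-all else branch needs.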

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 38 additions & 14 deletions
@@ -82,8 +82,18 @@ def run_trainer(self, place, args):
         strategy = fluid.ExecutionStrategy()
         strategy.num_threads = 1
         strategy.allow_op_delay = False
+        build_stra = fluid.BuildStrategy()
+
+        if args.use_reduce:
+            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
+        else:
+            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+
         exe = fluid.ParallelExecutor(
-            True, loss_name=avg_cost.name, exec_strategy=strategy)
+            True,
+            loss_name=avg_cost.name,
+            exec_strategy=strategy,
+            build_strategy=build_stra)

         feed_var_list = [
             var for var in trainer_prog.global_block().vars.values()
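
For context, a minimal, self-contained sketch of the ParallelExecutor setup this test now exercises. The tiny regression network and variable names below are illustrative placeholders, not part of the commit; Reduce mode is understood to shard gradient aggregation across devices, while AllReduce (the default) aggregates every gradient on every device.

import paddle.fluid as fluid

# placeholder network so ParallelExecutor has a loss to optimize
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

# initialize parameters before building the parallel executor
place = fluid.CUDAPlace(0)
fluid.Executor(place).run(fluid.default_startup_program())

strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1

build_stra = fluid.BuildStrategy()
# switch between the two gradient-aggregation modes under test
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

exe = fluid.ParallelExecutor(
    True,  # use_cuda
    loss_name=avg_cost.name,
    exec_strategy=strategy,
    build_strategy=build_stra)
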
@@ -123,6 +133,7 @@ def runtime_main(test_class):
         '--current_endpoint', type=str, required=False, default="")
     parser.add_argument('--sync_mode', action='store_true')
     parser.add_argument('--mem_opt', action='store_true')
+    parser.add_argument('--use_reduce', action='store_true')

     args = parser.parse_args()
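
Since --use_reduce is a store_true switch, its mere presence on the trainer command line turns reduce mode on. A tiny standalone illustration (not part of the test itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_reduce', action='store_true')

args = parser.parse_args(['--use_reduce'])
print(args.use_reduce, args.sync_mode)  # True False: only flags actually passed become True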

@@ -149,20 +160,25 @@ def setUp(self):
         self._python_interp = "python"
         self._sync_mode = True
         self._mem_opt = False
+        self._use_reduce = False
         self._setup_config()

     def start_pserver(self, model_file, check_error_log):
-
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
-        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s"
-        sync_mode_str = "--sync_mode" if self._sync_mode else ""
-        mem_opt_str = "--mem_opt" if self._mem_opt else ""
+        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
         ps0_cmd = ps_cmd % \
             (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
-             self._trainers, sync_mode_str, mem_opt_str)
+             self._trainers)
         ps1_cmd = ps_cmd % \
             (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
-             self._trainers, sync_mode_str, mem_opt_str)
+             self._trainers)
+
+        if self._sync_mode:
+            ps0_cmd += " --sync_mode"
+            ps1_cmd += " --sync_mode"
+        if self._mem_opt:
+            ps0_cmd += " --mem_opt"
+            ps1_cmd += " --mem_opt"

         ps0_pipe = subprocess.PIPE
         ps1_pipe = subprocess.PIPE
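
The reason for moving the optional flags out of the format string: when a flag was absent, the old template interpolated an empty string into "... --is_dist %s %s", leaving stray spaces in the command; once such a command is split on spaces for subprocess, the empty slots become empty argv entries that argparse may reject. A small illustration of the new pattern (the interpreter, script and endpoint values are placeholders):

ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
cmd = ps_cmd % ("python", "dist_mnist.py",
                "127.0.0.1:6170,127.0.0.1:6171", "127.0.0.1:6170", 2)
sync_mode = True   # stands in for self._sync_mode
if sync_mode:
    cmd += " --sync_mode"
print(cmd.split(" "))  # clean token list; interpolating "" would have left empty tokens
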
@@ -242,17 +258,23 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         self._wait_ps_ready(ps1.pid)

         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
-        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s"
-        sync_mode_str = "--sync_mode" if self._sync_mode else ""
-        mem_opt_str = "--mem_opt" if self._mem_opt else ""
+        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
         tr0_cmd = tr_cmd % \
             (self._python_interp, model_file, self._ps_endpoints,
-             0, ps0_ep,
-             self._trainers, sync_mode_str, mem_opt_str)
+             0, ps0_ep, self._trainers)
         tr1_cmd = tr_cmd % \
             (self._python_interp, model_file, self._ps_endpoints,
-             1, ps1_ep,
-             self._trainers, sync_mode_str, mem_opt_str)
+             1, ps1_ep, self._trainers)
+
+        if self._sync_mode:
+            tr0_cmd += " --sync_mode"
+            tr1_cmd += " --sync_mode"
+        if self._mem_opt:
+            tr0_cmd += " --mem_opt"
+            tr1_cmd += " --mem_opt"
+        if self._use_reduce:
+            tr0_cmd += " --use_reduce"
+            tr1_cmd += " --use_reduce"

         env0 = {"CUDA_VISIBLE_DEVICES": "0"}
         env1 = {"CUDA_VISIBLE_DEVICES": "1"}
@@ -303,6 +325,8 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # FIXME: use terminate() instead of sigkill.
         os.kill(ps0.pid, signal.SIGKILL)
         os.kill(ps1.pid, signal.SIGKILL)
+        ps0.terminate()
+        ps1.terminate()
         ps0.wait()
         ps1.wait()
         FNULL.close()
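
The added ps0.terminate() / ps1.terminate() calls come right after the SIGKILLs; sending SIGTERM to a child that has already been killed but not yet waited on is harmless, so this reads as a belt-and-braces step toward the FIXME above rather than a change in behavior.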

python/paddle/fluid/tests/unittests/test_dist_mnist.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,7 @@
 class TestDistMnist2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
+        self._use_reduce = False

     def test_se_resnext(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)
@@ -37,10 +38,30 @@ def test_se_resnext(self):
 class TestDistMnistAsync(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
+        self._use_reduce = False

     def test_se_resnext(self):
         self.check_with_place("dist_mnist.py", delta=200)


+# FIXME(typhoonzero): enable these tests once we have 4
+# 4 GPUs on CI machine, and the base class should be updated.
+#
+# class TestDistMnist2x2ReduceMode(TestDistBase):
+#     def _setup_config(self):
+#         self._sync_mode = True
+#         self._use_reduce = True
+
+#     def test_se_resnext(self):
+#         self.check_with_place("dist_mnist.py", delta=1e-7)
+
+# class TestDistMnistAsyncReduceMode(TestDistBase):
+#     def _setup_config(self):
+#         self._sync_mode = False
+#         self._use_reduce = True
+
+#     def test_se_resnext(self):
+#         self.check_with_place("dist_mnist.py", delta=200)
+
 if __name__ == "__main__":
     unittest.main()
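
Read together with test_dist_base.py, the flow of the new option is: a test case sets self._use_reduce = True in _setup_config, TestDistBase appends --use_reduce to both trainer commands, and run_trainer translates the flag into BuildStrategy.ReduceStrategy.Reduce for the ParallelExecutor. The reduce-mode test classes stay commented out until the CI machines have enough GPUs.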

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 19 additions & 6 deletions
@@ -273,6 +273,10 @@ def transpile(self,
                     name=framework.generate_control_dev_var_name())
                 grad_name_to_send_dummy_out[grad_varname] = dummy_output

+                # get send op_role_var, if not splited, the grad should have .trainer suffix
+                # if splited, grad should be the original grad var name (split_by_ref and send
+                # will be on the same place). ParallelExecutor
+                # will use op_role_var to get expected device place to run this op.
                 program.global_block()._insert_op(
                     index=index + 1,
                     type="send",
@@ -281,8 +285,10 @@ def transpile(self,
                     attrs={
                         "epmap": eplist,
                         RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
-                        OP_ROLE_VAR_ATTR_NAME:
-                        [self.grad_name_to_param_name[grad_varname], grad_varname],
+                        OP_ROLE_VAR_ATTR_NAME: [
+                            self.grad_name_to_param_name[grad_varname],
+                            splited_grad_varname
+                        ],
                         "sync_mode": not self.sync_mode,
                     })
             for _, var in enumerate(splited_vars):
@@ -326,17 +332,24 @@ def transpile(self,
             recv_dep_in = grad_name_to_send_dummy_out[
                 self.param_name_to_grad_name[param_varname]]
             all_recv_outputs.extend(splited_var)
+            # get recv op_role_var, if not splited, the grad should have .trainer suffix
+            # if splited, grad should be the original grad var name. ParallelExecutor
+            # will use op_role_var to get expected device place to run this op.
+            orig_grad_name = self.param_name_to_grad_name[param_varname]
+            recv_op_role_var_name = orig_grad_name
+            splited_trainer_grad = self.grad_var_mapping[orig_grad_name]
+            if len(splited_trainer_grad) == 1:
+                recv_op_role_var_name = splited_trainer_grad[0].name
+
             program.global_block().append_op(
                 type="recv",
                 inputs={"X": [recv_dep_in]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
-                    OP_ROLE_VAR_ATTR_NAME: [
-                        param_varname,
-                        self.param_name_to_grad_name[param_varname]
-                    ],
+                    OP_ROLE_VAR_ATTR_NAME:
+                    [param_varname, recv_op_role_var_name],
                     "sync_mode": not self.sync_mode
                 })
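
The recv-side selection is easier to see outside the transpiler. Below is a standalone sketch of the same rule; the Var stand-in, the mapping contents, and the @GRAD / .trainer_0 / .block0 naming are illustrative assumptions, not code from this commit.

from collections import namedtuple

Var = namedtuple("Var", ["name"])  # stand-in for a fluid Variable


def pick_recv_op_role_var(param_name, param_name_to_grad_name, grad_var_mapping):
    # un-splited grad (a single slice): report the trainer-suffixed slice name;
    # splited grad: report the original gradient name, so ParallelExecutor can
    # place the recv op on the expected device.
    orig_grad_name = param_name_to_grad_name[param_name]
    splited = grad_var_mapping[orig_grad_name]
    if len(splited) == 1:
        return splited[0].name
    return orig_grad_name


# gradient that was not splited -> the single slice name is reported
print(pick_recv_op_role_var(
    "fc_0.b_0",
    {"fc_0.b_0": "fc_0.b_0@GRAD"},
    {"fc_0.b_0@GRAD": [Var("fc_0.b_0@GRAD.trainer_0")]}))

# gradient splited into two blocks -> the original gradient name is reported
print(pick_recv_op_role_var(
    "fc_0.w_0",
    {"fc_0.w_0": "fc_0.w_0@GRAD"},
    {"fc_0.w_0@GRAD": [Var("fc_0.w_0@GRAD.block0"),
                       Var("fc_0.w_0@GRAD.block1")]}))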
