
Commit 03b45b6

fix dist train reduce mode (#13068)

* fix dist train reduce mode
* fix previous fix

1 parent: f517d01

File tree

4 files changed (+122, -27 lines)

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -744,7 +744,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
-    PADDLE_ENFORCE(
+    PADDLE_THROW(
         "the distribute training related op should be in [split_byref, "
         "concat].");
   }
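
The one-line change above matters because PADDLE_ENFORCE treats its first argument as the condition to check: called with only a message string, the (truthy) string never fails, so this else-branch could never report the unsupported op. PADDLE_THROW raises unconditionally. A rough Python analogue of the distinction (illustration only, not the Paddle API):

    # Hypothetical stand-ins for the two C++ macros, to show why the swap matters.
    def enforce_like(cond, msg=""):
        # PADDLE_ENFORCE-style: raises only when the first argument is falsy.
        if not cond:
            raise RuntimeError(msg)

    def throw_like(msg):
        # PADDLE_THROW-style: always raises.
        raise RuntimeError(msg)

    # Passing only a non-empty message makes the "enforce" a silent no-op ...
    enforce_like("the distribute training related op should be in [split_byref, concat].")
    # ... while the "throw" form would report the unsupported op as intended:
    # throw_like("the distribute training related op should be in [split_byref, concat].")
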

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 60 additions & 19 deletions
@@ -76,8 +76,18 @@ def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
         strategy = fluid.ExecutionStrategy()
         strategy.num_threads = 1
         strategy.allow_op_delay = False
+        build_stra = fluid.BuildStrategy()
+
+        if args.use_reduce:
+            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
+        else:
+            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+
         exe = fluid.ParallelExecutor(
-            True, loss_name=avg_cost.name, exec_strategy=strategy)
+            True,
+            loss_name=avg_cost.name,
+            exec_strategy=strategy,
+            build_strategy=build_stra)

         feed_var_list = [
             var for var in trainer_prog.global_block().vars.values()
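
For context, the reduce_strategy toggle added above switches ParallelExecutor between two gradient-aggregation modes: AllReduce keeps a full copy of every parameter on every device and all-reduces its gradient, while Reduce (roughly speaking) assigns each parameter update to one device and broadcasts the result. A minimal sketch of driving that switch, assuming avg_cost comes from a network already built in the default program:

    import paddle.fluid as fluid

    def make_parallel_exe(avg_cost, use_reduce):
        # Pick the gradient aggregation mode under test.
        build_stra = fluid.BuildStrategy()
        build_stra.reduce_strategy = (
            fluid.BuildStrategy.ReduceStrategy.Reduce
            if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce)

        exec_stra = fluid.ExecutionStrategy()
        exec_stra.num_threads = 1

        # use_cuda=True mirrors the positional `True` in the test above.
        return fluid.ParallelExecutor(
            True,
            loss_name=avg_cost.name,
            exec_strategy=exec_stra,
            build_strategy=build_stra)
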
@@ -106,16 +116,20 @@ def runtime_main(test_class):
     import paddle.fluid as fluid
     import paddle.fluid.core as core

-    if len(sys.argv) != 7:
-        print(
-            "Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
-        )
-    role = sys.argv[1]
-    endpoints = sys.argv[2]
-    trainer_id = int(sys.argv[3])
-    current_endpoint = sys.argv[4]
-    trainers = int(sys.argv[5])
-    is_dist = True if sys.argv[6] == "TRUE" else False
+    parser = argparse.ArgumentParser(description='Run dist test.')
+    parser.add_argument(
+        '--role', type=str, required=True, choices=['pserver', 'trainer'])
+    parser.add_argument('--endpoints', type=str, required=False, default="")
+    parser.add_argument('--is_dist', action='store_true')
+    parser.add_argument('--trainer_id', type=int, required=False, default=0)
+    parser.add_argument('--trainers', type=int, required=False, default=1)
+    parser.add_argument(
+        '--current_endpoint', type=str, required=False, default="")
+    parser.add_argument('--sync_mode', action='store_true')
+    parser.add_argument('--mem_opt', action='store_true')
+    parser.add_argument('--use_reduce', action='store_true')
+
+    args = parser.parse_args()

     model = test_class()
     if role == "pserver":
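
With the positional sys.argv parsing replaced by argparse, the test scripts are driven by named flags, and the boolean options become store_true switches that are simply present or absent. A small, self-contained example of parsing the kind of invocation the new interface accepts (the argv string below is illustrative, reusing the endpoints from the test base):

    import argparse
    import shlex

    argv = shlex.split(
        "--role trainer --endpoints 127.0.0.1:9123,127.0.0.1:9124 "
        "--trainer_id 0 --current_endpoint 127.0.0.1:9123 --trainers 2 "
        "--is_dist --sync_mode --use_reduce")

    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument('--role', type=str, required=True,
                        choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, default="")
    parser.add_argument('--is_dist', action='store_true')
    parser.add_argument('--trainer_id', type=int, default=0)
    parser.add_argument('--trainers', type=int, default=1)
    parser.add_argument('--current_endpoint', type=str, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')

    args = parser.parse_args(argv)
    assert args.role == "trainer" and args.use_reduce and not args.mem_opt
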
@@ -135,16 +149,28 @@ def setUp(self):
         self._pservers = 2
         self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
         self._python_interp = "python"
+        self._sync_mode = True
+        self._mem_opt = False
+        self._use_reduce = False
+        self._setup_config()

     def start_pserver(self, model_file, check_error_log):
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
-        ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
+        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
+        ps0_cmd = ps_cmd % \
             (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
              self._trainers)
-        ps1_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
+        ps1_cmd = ps_cmd % \
             (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
              self._trainers)

+        if self._sync_mode:
+            ps0_cmd += " --sync_mode"
+            ps1_cmd += " --sync_mode"
+        if self._mem_opt:
+            ps0_cmd += " --mem_opt"
+            ps1_cmd += " --mem_opt"
+
         ps0_pipe = subprocess.PIPE
         ps1_pipe = subprocess.PIPE
         if check_error_log:
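
The ps_cmd template expands into one complete command line per parameter server, with the optional flags appended according to the subclass's _setup_config. A quick sketch of what that expansion yields (script name and ports are illustrative):

    # Reproduce the command-string expansion done in start_pserver above.
    python_interp = "python"
    model_file = "dist_mnist.py"
    ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
    ps0_ep, ps1_ep = ps_endpoints.split(",")
    trainers = 2

    ps_cmd = ("%s %s --role pserver --endpoints %s --trainer_id 0 "
              "--current_endpoint %s --trainers %d --is_dist")
    ps0_cmd = ps_cmd % (python_interp, model_file, ps_endpoints, ps0_ep, trainers)
    ps0_cmd += " --sync_mode"  # appended when self._sync_mode is True

    print(ps0_cmd)
    # python dist_mnist.py --role pserver --endpoints 127.0.0.1:9123,127.0.0.1:9124
    #   --trainer_id 0 --current_endpoint 127.0.0.1:9123 --trainers 2 --is_dist --sync_mode
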
@@ -226,12 +252,23 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         self._wait_ps_ready(ps1.pid)

         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
-        tr0_cmd = "%s %s trainer %s 0 %s %d TRUE" % \
-            (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
-             self._trainers)
-        tr1_cmd = "%s %s trainer %s 1 %s %d TRUE" % \
-            (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
-             self._trainers)
+        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
+        tr0_cmd = tr_cmd % \
+            (self._python_interp, model_file, self._ps_endpoints,
+             0, ps0_ep, self._trainers)
+        tr1_cmd = tr_cmd % \
+            (self._python_interp, model_file, self._ps_endpoints,
+             1, ps1_ep, self._trainers)
+
+        if self._sync_mode:
+            tr0_cmd += " --sync_mode"
+            tr1_cmd += " --sync_mode"
+        if self._mem_opt:
+            tr0_cmd += " --mem_opt"
+            tr1_cmd += " --mem_opt"
+        if self._use_reduce:
+            tr0_cmd += " --use_reduce"
+            tr1_cmd += " --use_reduce"

         env0 = {"CUDA_VISIBLE_DEVICES": "0"}
         env1 = {"CUDA_VISIBLE_DEVICES": "1"}
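
The env0/env1 dictionaries pin each trainer subprocess to its own GPU through CUDA_VISIBLE_DEVICES, which is how the two-trainer test fits on a two-GPU CI machine. A minimal sketch of launching one pinned trainer (command string and environment merging are illustrative):

    import os
    import subprocess

    tr0_cmd = "python dist_mnist.py --role trainer --trainer_id 0 --is_dist"
    env0 = {"CUDA_VISIBLE_DEVICES": "0"}  # trainer 0 only sees GPU 0

    # Merge with the parent environment so PATH, PYTHONPATH, etc. survive.
    child_env = dict(os.environ, **env0)
    tr0_proc = subprocess.Popen(
        tr0_cmd.split(" "),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=child_env)
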
@@ -282,6 +319,10 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # FIXME: use terminate() instead of sigkill.
         os.kill(ps0.pid, signal.SIGKILL)
         os.kill(ps1.pid, signal.SIGKILL)
+        ps0.terminate()
+        ps1.terminate()
+        ps0.wait()
+        ps1.wait()
         FNULL.close()

         self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta)
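
Adding terminate() and wait() after the SIGKILL makes sure the Popen objects are actually reaped, so the parameter-server processes do not linger as zombies between test cases. The same cleanup pattern in isolation, as a hedged sketch (Unix-only because of SIGKILL):

    import signal
    import subprocess

    def stop_pserver(proc):
        # Force-stop the parameter server, then reap it so no zombie remains.
        try:
            proc.send_signal(signal.SIGKILL)  # like os.kill(proc.pid, SIGKILL)
        except OSError:
            pass                              # process already gone
        proc.terminate()                      # harmless if it has already exited
        proc.wait()                           # collect the exit status

    ps0 = subprocess.Popen(["sleep", "1000"])  # stand-in for a pserver process
    stop_pserver(ps0)
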

python/paddle/fluid/tests/unittests/test_dist_mnist.py

Lines changed: 42 additions & 1 deletion
@@ -17,10 +17,51 @@
 from test_dist_base import TestDistBase


-class TestDistSeResneXt2x2(TestDistBase):
+class TestDistMnist2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+
+    def test_se_resnext(self):
+        self.check_with_place("dist_mnist.py", delta=1e-7)
+
+
+class TestDistMnist2x2WithMemopt(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._mem_opt = True
+
     def test_se_resnext(self):
         self.check_with_place("dist_mnist.py", delta=1e-7)


+class TestDistMnistAsync(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._use_reduce = False
+
+    def test_se_resnext(self):
+        self.check_with_place("dist_mnist.py", delta=200)
+
+
+# FIXME(typhoonzero): enable these tests once we have 4
+# 4 GPUs on CI machine, and the base class should be updated.
+#
+# class TestDistMnist2x2ReduceMode(TestDistBase):
+#     def _setup_config(self):
+#         self._sync_mode = True
+#         self._use_reduce = True
+
+#     def test_se_resnext(self):
+#         self.check_with_place("dist_mnist.py", delta=1e-7)
+
+# class TestDistMnistAsyncReduceMode(TestDistBase):
+#     def _setup_config(self):
+#         self._sync_mode = False
+#         self._use_reduce = True
+
+#     def test_se_resnext(self):
+#         self.check_with_place("dist_mnist.py", delta=200)
+
 if __name__ == "__main__":
     unittest.main()
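
The reduce-mode variants stay commented out because the CI machine exposes only two GPUs. One way to keep such tests in the file while gating them on available hardware, sketched with unittest.skipUnless and a hypothetical four-GPU requirement (this guard is not part of the commit):

    import unittest
    import paddle.fluid.core as core

    from test_dist_base import TestDistBase

    # Hypothetical guard: run only when at least four CUDA devices are visible.
    ENOUGH_GPUS = core.is_compiled_with_cuda() and core.get_cuda_device_count() >= 4

    @unittest.skipUnless(ENOUGH_GPUS, "reduce-mode dist test needs 4 GPUs")
    class TestDistMnist2x2ReduceMode(TestDistBase):
        def _setup_config(self):
            self._sync_mode = True
            self._use_reduce = True

        def test_dist_mnist_reduce(self):
            self.check_with_place("dist_mnist.py", delta=1e-7)
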

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 19 additions & 6 deletions
@@ -265,6 +265,10 @@ def transpile(self,
                 name=framework.generate_control_dev_var_name())
             grad_name_to_send_dummy_out[grad_varname] = dummy_output

+            # get send op_role_var, if not splited, the grad should have .trainer suffix
+            # if splited, grad should be the original grad var name (split_by_ref and send
+            # will be on the same place). ParallelExecutor
+            # will use op_role_var to get expected device place to run this op.
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
@@ -273,8 +277,10 @@ def transpile(self,
                 attrs={
                     "epmap": eplist,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
-                    OP_ROLE_VAR_ATTR_NAME:
-                    [self.grad_name_to_param_name[grad_varname], grad_varname],
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        self.grad_name_to_param_name[grad_varname],
+                        splited_grad_varname
+                    ],
                     "sync_mode": not self.sync_mode,
                 })
             for _, var in enumerate(splited_vars):
@@ -318,17 +324,24 @@ def transpile(self,
             recv_dep_in = grad_name_to_send_dummy_out[
                 self.param_name_to_grad_name[param_varname]]
             all_recv_outputs.extend(splited_var)
+            # get recv op_role_var, if not splited, the grad should have .trainer suffix
+            # if splited, grad should be the original grad var name. ParallelExecutor
+            # will use op_role_var to get expected device place to run this op.
+            orig_grad_name = self.param_name_to_grad_name[param_varname]
+            recv_op_role_var_name = orig_grad_name
+            splited_trainer_grad = self.grad_var_mapping[orig_grad_name]
+            if len(splited_trainer_grad) == 1:
+                recv_op_role_var_name = splited_trainer_grad[0].name
+
             program.global_block().append_op(
                 type="recv",
                 inputs={"X": [recv_dep_in]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
-                    OP_ROLE_VAR_ATTR_NAME: [
-                        param_varname,
-                        self.param_name_to_grad_name[param_varname]
-                    ],
+                    OP_ROLE_VAR_ATTR_NAME:
+                    [param_varname, recv_op_role_var_name],
                     "sync_mode": not self.sync_mode
                 })
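
The new recv_op_role_var_name logic picks which gradient name is attached to the recv op's op_role_var: if the gradient was not split across parameter servers (the mapping holds a single entry), the per-trainer variable name is used; otherwise the original gradient name is kept, so the multi-device graph builder can place the op on the device that owns that gradient in Reduce mode. A standalone sketch of the same selection rule (variable names and the mapping below are made up for illustration):

    from collections import namedtuple

    Var = namedtuple("Var", ["name"])

    def pick_recv_op_role_var(orig_grad_name, grad_var_mapping):
        # Mirrors the selection added in this commit: keep the original grad
        # name unless the grad was left unsplit, in which case use the single
        # trainer-side variable's name.
        splited_trainer_grad = grad_var_mapping[orig_grad_name]
        if len(splited_trainer_grad) == 1:
            return splited_trainer_grad[0].name
        return orig_grad_name

    grad_var_mapping = {
        "fc_0.w_0@GRAD": [Var("fc_0.w_0@GRAD.trainer_0")],   # not split
        "fc_1.w_0@GRAD": [Var("fc_1.w_0@GRAD.block0"),
                          Var("fc_1.w_0@GRAD.block1")],      # split in two
    }

    assert pick_recv_op_role_var("fc_0.w_0@GRAD", grad_var_mapping) == "fc_0.w_0@GRAD.trainer_0"
    assert pick_recv_op_role_var("fc_1.w_0@GRAD", grad_var_mapping) == "fc_1.w_0@GRAD"
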
