
Commit ec773f9

Author: yi.wu (committed)
fix ut merge error
1 parent 1b79974 commit ec773f9

File tree

1 file changed (+24 lines, -21 lines)


python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 24 additions & 21 deletions
@@ -21,7 +21,7 @@
 import six
 import signal
 import subprocess
-import six
+import argparse
 
 
 class TestDistRunnerBase(object):
@@ -30,7 +30,7 @@ def get_model(self, batch_size=2):
             "get_model should be implemented by child classes.")
 
     def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
-                       trainers):
+                       trainers, sync_mode):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         import paddle
         import paddle.fluid as fluid
@@ -39,33 +39,35 @@ def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
             trainer_id=trainer_id,
             program=main_program,
             pservers=pserver_endpoints,
-            trainers=trainers)
+            trainers=trainers,
+            sync_mode=sync_mode)
         return t
 
-    def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
+    def run_pserver(self, args):
         import paddle
         import paddle.fluid as fluid
         self.get_model(batch_size=2)
-        t = self.get_transpiler(trainer_id,
-                                fluid.default_main_program(), pserver_endpoints,
-                                trainers)
-        pserver_prog = t.get_pserver_program(current_endpoint)
-        startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
+        t = self.get_transpiler(args.trainer_id,
+                                fluid.default_main_program(), args.endpoints,
+                                args.trainers, args.sync_mode)
+        pserver_prog = t.get_pserver_program(args.current_endpoint)
+        startup_prog = t.get_startup_program(args.current_endpoint,
+                                             pserver_prog)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
+    def run_trainer(self, place, args):
         import paddle
         import paddle.fluid as fluid
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
-        if is_dist:
-            t = self.get_transpiler(trainer_id,
-                                    fluid.default_main_program(), endpoints,
-                                    trainers)
+        if args.is_dist:
+            t = self.get_transpiler(args.trainer_id,
+                                    fluid.default_main_program(),
+                                    args.endpoints, args.trainers,
+                                    args.sync_mode)
             trainer_prog = t.get_trainer_program()
         else:
             trainer_prog = fluid.default_main_program()
@@ -132,18 +134,21 @@ def runtime_main(test_class):
     args = parser.parse_args()
 
     model = test_class()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
+    if args.role == "pserver":
+        model.run_pserver(args)
     else:
         p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+        model.run_trainer(p, args)
 
 
 import paddle.compat as cpt
 
 
 class TestDistBase(unittest.TestCase):
+    def _setup_config(self):
+        raise NotImplementedError("tests should have _setup_config implemented")
+
     def setUp(self):
         self._trainers = 2
         self._pservers = 2
@@ -221,9 +226,7 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # Run local to get a base line
         env_local = {"CUDA_VISIBLE_DEVICES": "0"}
         env_local.update(required_envs)
-        local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
-                    (self._python_interp, model_file,
-                     "127.0.0.1:1234", "127.0.0.1:1234", 1)
+        local_cmd = "%s %s --role trainer" % (self._python_interp, model_file)
         if not check_error_log:
             local_proc = subprocess.Popen(
                 local_cmd.split(" "),
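The diff above is one refactor applied end to end: `runtime_main` now parses its command line with `argparse` (note the `import argparse` swap in the first hunk) and hands the parsed namespace straight to `run_pserver` and `run_trainer`, instead of threading five positional values through every call. The parser definition itself sits outside these hunks; judging only from the `args.*` attributes the diff touches, a minimal sketch of it could look like this (flag types and defaults are assumptions, not taken from the commit):

# Hypothetical reconstruction of the parser consumed above. Only the attribute
# names (role, endpoints, current_endpoint, trainer_id, trainers, is_dist,
# sync_mode) are confirmed by the diff; types and defaults are guesses.
import argparse

def parse_runtime_args():
    parser = argparse.ArgumentParser(description="distributed unit-test worker")
    parser.add_argument("--role", choices=["pserver", "trainer"], required=True)
    parser.add_argument("--endpoints", type=str, default="")  # pserver endpoint list
    parser.add_argument("--current_endpoint", type=str, default="")
    parser.add_argument("--trainer_id", type=int, default=0)
    parser.add_argument("--trainers", type=int, default=1)
    parser.add_argument("--is_dist", action="store_true")
    parser.add_argument("--sync_mode", action="store_true")  # new in this commit
    return parser.parse_args()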

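The commit also gives `TestDistBase` an abstract `_setup_config` hook, so every concrete test must declare its own configuration up front or fail with `NotImplementedError`. A minimal sketch of a subclass, assuming the base class reads a `_sync_mode` attribute to drive the new `sync_mode` transpiler flag (that attribute is not shown in this diff):

# Illustrative subclass only; the config attributes the base class actually
# reads are not visible in these hunks, so _sync_mode is an assumption.
from test_dist_base import TestDistBase  # the class patched above

class TestDistTrainSync(TestDistBase):
    def _setup_config(self):
        # choose synchronous parameter updates for this test (assumed knob)
        self._sync_mode = True

    def test_train(self):
        # "dist_model.py" is a placeholder for a real model script
        self.check_with_place("dist_model.py", delta=1e-3)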

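Finally, the last hunk shows the payoff: the old positional invocation baked in magic values and even a typo (`FLASE` for `FALSE`), while the new CLI needs only `--role trainer` for the local baseline run. A distributed pserver process would presumably be launched with the remaining flags filled in; a hedged sketch, with the endpoints, script name, and exact flag set made up for illustration:

# Illustrative only: launching a pserver worker under the new flag-based CLI.
import subprocess

python_interp = "python"      # stand-in for self._python_interp
model_file = "dist_model.py"  # stand-in for the test's model script
ps_cmd = "%s %s --role pserver --endpoints %s --current_endpoint %s --trainers %d --is_dist" % (
    python_interp, model_file, "127.0.0.1:6174,127.0.0.1:6175",
    "127.0.0.1:6174", 2)
ps_proc = subprocess.Popen(ps_cmd.split(" "))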