@@ -21,7 +21,7 @@
 import six
 import signal
 import subprocess
-import six
+import argparse


 class TestDistRunnerBase(object):
@@ -30,7 +30,7 @@ def get_model(self, batch_size=2):
             "get_model should be implemented by child classes.")

     def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
-                       trainers):
+                       trainers, sync_mode):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         import paddle
         import paddle.fluid as fluid
@@ -39,33 +39,35 @@ def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
             trainer_id=trainer_id,
             program=main_program,
             pservers=pserver_endpoints,
-            trainers=trainers)
+            trainers=trainers,
+            sync_mode=sync_mode)
         return t

-    def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
+    def run_pserver(self, args):
         import paddle
         import paddle.fluid as fluid
         self.get_model(batch_size=2)
-        t = self.get_transpiler(trainer_id,
-                                fluid.default_main_program(), pserver_endpoints,
-                                trainers)
-        pserver_prog = t.get_pserver_program(current_endpoint)
-        startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
+        t = self.get_transpiler(args.trainer_id,
+                                fluid.default_main_program(), args.endpoints,
+                                args.trainers, args.sync_mode)
+        pserver_prog = t.get_pserver_program(args.current_endpoint)
+        startup_prog = t.get_startup_program(args.current_endpoint,
+                                             pserver_prog)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_prog)
         exe.run(pserver_prog)

-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
+    def run_trainer(self, place, args):
         import paddle
         import paddle.fluid as fluid
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
-        if is_dist:
-            t = self.get_transpiler(trainer_id,
-                                    fluid.default_main_program(), endpoints,
-                                    trainers)
+        if args.is_dist:
+            t = self.get_transpiler(args.trainer_id,
+                                    fluid.default_main_program(),
+                                    args.endpoints, args.trainers,
+                                    args.sync_mode)
             trainer_prog = t.get_trainer_program()
         else:
             trainer_prog = fluid.default_main_program()
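The hunk above threads the new `sync_mode` flag from the test harness through `get_transpiler` into `DistributeTranspiler.transpile`, and collapses the positional parameters of `run_pserver`/`run_trainer` into a single `args` namespace. A hypothetical driver for the new entry points (the subclass name and endpoint values below are illustrative, not part of this diff) might look like:

# Hypothetical usage sketch -- not part of this diff. Shows how the new
# args-object entry points could be driven once positionals are gone.
import argparse

args = argparse.Namespace(
    role="pserver",
    endpoints="127.0.0.1:6174,127.0.0.1:6175",  # all pserver endpoints
    current_endpoint="127.0.0.1:6174",          # this worker's endpoint
    trainer_id=0,
    trainers=2,
    is_dist=True,
    sync_mode=False)  # False exercises the new async-update code path

model = TestDistMnist()  # any TestDistRunnerBase subclass (illustrative name)
model.run_pserver(args)  # transpiles, then blocks in exe.run(pserver_prog)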
@@ -132,18 +134,21 @@ def runtime_main(test_class):
     args = parser.parse_args()

     model = test_class()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
+    if args.role == "pserver":
+        model.run_pserver(args)
     else:
         p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+        model.run_trainer(p, args)


 import paddle.compat as cpt


 class TestDistBase(unittest.TestCase):
+    def _setup_config(self):
+        raise NotImplementedError("tests should have _setup_config implemented")
+
     def setUp(self):
         self._trainers = 2
         self._pservers = 2
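`runtime_main` now reads everything from the parsed `args` namespace instead of module-level positionals. The parser construction sits above this hunk and is not shown; a minimal sketch consistent with the attributes the new code reads (flag names and defaults here are assumptions, not copied from the source) would be:

# Assumed reconstruction of the argparse setup in runtime_main(); the real
# parser is outside this hunk. Flags are inferred from the attributes the
# new code reads (args.role, args.endpoints, args.sync_mode, ...).
import argparse

parser = argparse.ArgumentParser(description="Run one dist-test worker process.")
parser.add_argument(
    "--role", type=str, required=True, choices=["pserver", "trainer"])
parser.add_argument("--endpoints", type=str, default="")  # "ip:port,ip:port"
parser.add_argument("--current_endpoint", type=str, default="")
parser.add_argument("--trainer_id", type=int, default=0)
parser.add_argument("--trainers", type=int, default=1)
parser.add_argument("--is_dist", action="store_true")    # omit => local run
parser.add_argument("--sync_mode", action="store_true")  # omit => async
args = parser.parse_args()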
@@ -221,9 +226,7 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # Run local to get a base line
         env_local = {"CUDA_VISIBLE_DEVICES": "0"}
         env_local.update(required_envs)
-        local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
-                    (self._python_interp, model_file,
-                     "127.0.0.1:1234", "127.0.0.1:1234", 1)
+        local_cmd = "%s %s --role trainer" % (self._python_interp, model_file)
         if not check_error_log:
             local_proc = subprocess.Popen(
                 local_cmd.split(" "),
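The local-baseline command drops the old positional string, which padded in dummy endpoints and a misspelled sync flag (`FLASE`). Since `--is_dist` now defaults to off, `--role trainer` alone selects the non-distributed branch of `run_trainer`. An illustrative side-by-side (the script path is made up):

# Illustrative comparison of the two command lines (script path is made up).
old_cmd = ("python dist_model.py trainer "
           "127.0.0.1:1234 0 127.0.0.1:1234 1 FLASE")  # positional, typo'd flag
new_cmd = "python dist_model.py --role trainer"        # argparse, local run

# subprocess.Popen(cmd.split(" ")) works for both, but only the new form is
# self-describing and extensible (e.g. later appending --sync_mode).
print(new_cmd.split(" "))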