 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import core
 import framework
 import executor

 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
+import distribute_transpiler

 __all__ = [
     'Trainer',
@@ -76,22 +78,61 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
                 raise TypeError(
                     "The optimizer should be an instance of Optimizer")

-            optimizer.minimize(loss)
+            optimize_ops, params_grads = optimizer.minimize(loss)

         self.place = Trainer._check_and_get_place(place)

+        self.dist_transpile_if_necessary(optimize_ops, params_grads)
+
         # 2. move the default_main_program to self.program and run the
         # default_startup program on an empty core.Scope()
         # Run startup program
-        exe = executor.Executor(place)
-        exe.run(self.startup_program, scope=self.scope)
+        with self._prog_and_scope_guard():
+            exe = executor.Executor(place)
+            exe.run(self.startup_program)

         if param_path:
             # load params from param_path into scope
             # TODO(yuyang): This depends on parameters implementation.
             pass

-        # TODO(helin): support distributed training
+    def dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        if "PADDLE_TRAINING_ROLE" not in os.environ:
+            return
+
+        # the port of all pservers, needed by both trainer and pserver
+        port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+        # comma separated ips of all pservers, needed by trainer and
+        # pserver
+        pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+        eplist = []
+        for ip in pserver_ips.split(","):
+            eplist.append(':'.join([ip, port]))
+        pserver_endpoints = ",".join(eplist)
+        # total number of workers/trainers in the job, needed by
+        # trainer and pserver
+        trainers = int(os.getenv("PADDLE_TRAINERS"))
+        # the IP of the local machine, needed by pserver only
+        current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+        # the unique trainer id, starting from 0, needed by trainer
+        # only
+        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        # the role, should be either PSERVER or TRAINER
+        training_role = os.getenv("PADDLE_TRAINING_ROLE")
+        with self._prog_and_scope_guard():
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if training_role == "PSERVER":
+                self.train_program = t.get_pserver_program(current_endpoint)
+                self.startup_program = t.get_startup_program(current_endpoint,
+                                                             self.train_program)
+            elif training_role == "TRAINER":
+                self.train_program = t.get_trainer_program()
+            else:
+                raise ValueError(
+                    'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+                )

     def train(self,
               num_epochs,
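The new dist_transpile_if_necessary() method added above is driven entirely by environment variables, so the same training script can be started unchanged on every node of a job. Below is a minimal sketch of that variable protocol; only the PADDLE_* names and the default port 6174 come from the patch, while the IP addresses and the train_fn/sgd_optimizer objects are hypothetical placeholders.

import os

# Variables read by both pservers and trainers (names from the patch above).
os.environ["PADDLE_PSERVER_PORT"] = "6174"                    # port shared by all pservers
os.environ["PADDLE_PSERVER_IPS"] = "192.168.1.2,192.168.1.3"  # comma separated pserver IPs
os.environ["PADDLE_TRAINERS"] = "2"                           # total number of trainers in the job

# Role-specific variables for this particular process.
os.environ["PADDLE_TRAINING_ROLE"] = "TRAINER"                # or "PSERVER"
os.environ["PADDLE_TRAINER_ID"] = "0"                         # unique trainer id, starting from 0
os.environ["PADDLE_CURRENT_IP"] = "192.168.1.2"               # local IP, only read by pservers

# With the variables exported, the constructor transpiles the program for this
# role (train_fn and sgd_optimizer are hypothetical placeholders):
# trainer = Trainer(program_func=train_fn, optimizer=sgd_optimizer)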
@@ -117,6 +158,13 @@ def train(self,
             raise NotImplementedError(
                 "Parallel Executor version of trainer is not implemented")

+        training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
+        if training_role == "PSERVER":
+            with self._prog_and_scope_guard():
+                exe = executor.Executor(self.place)
+                exe.run()
+            return
+
         self._train_by_executor(num_epochs, event_handler, reader, feed_order)

     def test(self, reader):
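As a usage note, the same script is then launched on every node and the role decides what train() does: a PSERVER process runs the transpiled listen-and-serve program through exe.run() and returns, while a TRAINER process falls through to _train_by_executor() as in local training. The sketch below assumes hypothetical train_fn, sgd_optimizer, train_reader and handle_event objects; only Trainer and the argument names passed to _train_by_executor() come from this file.

# Hypothetical user script; the placeholder objects are not part of the patch.
trainer = Trainer(program_func=train_fn, optimizer=sgd_optimizer)

# In a PADDLE_TRAINING_ROLE=PSERVER process this call blocks in exe.run()
# serving parameters and then returns; in a TRAINER process it trains.
trainer.train(
    num_epochs=1,
    event_handler=handle_event,
    reader=train_reader,
    feed_order=['x', 'y'])  # feed variable names; exact format is assumed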