@@ -20,7 +20,8 @@
 import contextlib
 import io
 import unique_name
+import parallel_executor


 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
@@ -48,12 +49,14 @@ class BeginStepEvent(object):
     def __init__(self, epoch_id, step_id):
         self.epoch = epoch_id
         self.step = step_id
+        self.fetch_metrics = True


 class EndStepEvent(object):
-    def __init__(self, epoch_id, step_id):
+    def __init__(self, epoch_id, step_id, metrics):
         self.epoch = epoch_id
         self.step = step_id
+        self.metrics = metrics


 def check_and_get_place(place):
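
Together, the two new fields give event handlers control over metric fetching: clearing fetch_metrics on a BeginStepEvent skips the fetch for that step, and the fetched values arrive on EndStepEvent.metrics. A minimal sketch of such a handler, assuming the event classes are re-exported as fluid.BeginStepEvent and fluid.EndStepEvent; the every-10-steps interval is an arbitrary choice for this sketch:

    import paddle.fluid as fluid

    def event_handler(event):
        if isinstance(event, fluid.BeginStepEvent):
            # Fetching forces a fetch of every train_func output;
            # only pay that cost every 10 steps (arbitrary interval).
            event.fetch_metrics = (event.step % 10 == 0)
        elif isinstance(event, fluid.EndStepEvent):
            if event.metrics:  # empty when fetch_metrics was False
                # metrics mirrors train_func's outputs, so metrics[0] is the loss.
                print("epoch %d, step %d, loss %s" %
                      (event.epoch, event.step, event.metrics[0]))
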
@@ -87,12 +90,17 @@ class Trainer(object):

     Args:
         train_func(callable): A function which will return loss. The loss must be a scalar.
-        infer_func(callable): A function which will return predict, used to save inference model
         optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
         place: The device place of this trainer.
     """

-    def __init__(self, train_func, optimizer, param_path=None, place=None):
+    def __init__(self,
+                 train_func,
+                 optimizer,
+                 param_path=None,
+                 place=None,
+                 parallel=False):
+        self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
         # program_func. Reference: fluid.program_guard in
         # test_word2vec.py
@@ -106,14 +114,14 @@ def __init__(self, train_func, optimizer, param_path=None, place=None):

         with framework.program_guard(self.train_program, self.startup_program):
             program_func_outs = train_func()
-            self.test_outputs = program_func_outs if isinstance(
+            self.train_func_outputs = program_func_outs if isinstance(
                 program_func_outs, list) else [program_func_outs]
             self.test_program = self.train_program.clone()
             if not isinstance(optimizer, opt_module.Optimizer):
                 raise TypeError(
                     "The optimizer should be an instance of Optimizer")
             # The first element of program_func_outs is loss.
-            loss = self.test_outputs[0]
+            loss = self.train_func_outputs[0]
             optimize_ops, params_grads = optimizer.minimize(loss)

         self.place = check_and_get_place(place)
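
Since the loss-first contract matters to both executors (the loss name also seeds the ParallelExecutor later in this diff), here is a hedged sketch of a conforming train_func; the network, the 'x'/'y' variable names, and the shapes are illustrative only, not part of the PR:

    import paddle.fluid as fluid

    def train_func():
        # Illustrative linear-regression network; names and shapes are
        # made up for this sketch.
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        y_predict = fluid.layers.fc(input=x, size=1, act=None)
        loss = fluid.layers.mean(
            fluid.layers.square_error_cost(input=y_predict, label=y))
        return [loss]  # the first element must be the loss
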
@@ -202,38 +210,32 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
                 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
             )

-    def train(self,
-              num_epochs,
-              event_handler,
-              reader,
-              feed_order,
-              parallel=False):
+    def train(self, num_epochs, event_handler, reader=None, feed_order=None):
         """
         Train the model.

         Args:
             num_epochs: The number of epochs. An epoch processes all the data in the reader.
             event_handler: The event handler. A function with type (ev: Event) -> void.
             reader:
-            parallel: True to train on multiple CPUs or GPUs
             feed_order: Feeding order of the reader. None follows the defining
                 order in the program.

         Returns:

         """
-        if parallel:
-            raise NotImplementedError(
-                "Parallel Executor version of trainer is not implemented")
-
         training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
         if training_role == "PSERVER":
             with self._prog_and_scope_guard():
                 exe = executor.Executor(self.place)
                 exe.run()
                 return
-
-        self._train_by_executor(num_epochs, event_handler, reader, feed_order)
+        if self.parallel:
+            self._train_by_parallel_executor(num_epochs, event_handler, reader,
+                                             feed_order)
+        else:
+            self._train_by_executor(num_epochs, event_handler, reader,
+                                    feed_order)

     def test(self, reader, feed_order):
         """
@@ -245,7 +247,8 @@ def test(self, reader, feed_order):
                 order in program
         """

-        return self._test_by_executor(reader, feed_order, self.test_outputs)
+        return self._test_by_executor(reader, feed_order,
+                                      self.train_func_outputs)

     def save_params(self, param_path):
         # reference: save_persistables in io.py
@@ -279,13 +282,25 @@ def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
             feeder = data_feeder.DataFeeder(
                 feed_list=feed_var_list, place=self.place)
             exe = executor.Executor(self.place)
-            for epoch_id in range(num_epochs):
-                event_handler(BeginEpochEvent(epoch_id))
-                for step_id, data in enumerate(reader()):
-                    event_handler(BeginStepEvent(epoch_id, step_id))
-                    exe.run(feed=feeder.feed(data), fetch_list=[])
-                    event_handler(EndStepEvent(epoch_id, step_id))
-                event_handler(EndEpochEvent(epoch_id))
+            reader = feeder.decorate_reader(reader, multi_devices=False)
+            self._train_by_any_executor(event_handler, exe, num_epochs, reader)
+
+    def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
+        for epoch_id in range(num_epochs):
+            event_handler(BeginEpochEvent(epoch_id))
+            for step_id, data in enumerate(reader()):
+                begin_event = BeginStepEvent(epoch_id, step_id)
+                event_handler(begin_event)
+                if begin_event.fetch_metrics:
+                    metrics = exe.run(feed=data,
+                                      fetch_list=[
+                                          var.name
+                                          for var in self.train_func_outputs
+                                      ])
+                else:
+                    metrics = exe.run(feed=data, fetch_list=[])
+                event_handler(EndStepEvent(epoch_id, step_id, metrics))
+            event_handler(EndEpochEvent(epoch_id))

     def _test_by_executor(self, reader, feed_order, fetch_list):
         with executor.scope_guard(self.scope):
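
Note the feed contract change: the loop now passes data straight to exe.run() because DataFeeder.decorate_reader converts each mini-batch into the feed dict the executor expects (and shards it across devices when multi_devices=True). A sketch of a compatible batched reader, reusing the illustrative 'x'/'y' variables from the train_func sketch above; the uci_housing dataset and batch size are arbitrary choices:

    import paddle

    # A batched reader yields lists of samples; decorate_reader maps each
    # sample's fields onto the feed variables named in feed_order.
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
        batch_size=20)
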
@@ -304,6 +319,26 @@ def _test_by_executor(self, reader, feed_order, fetch_list):

             return [x / count for x in accumulated]

+    def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
+                                    feed_order):
+        with self._prog_and_scope_guard():
+            pe = self._get_or_create_parallel_executor()
+            feed_var_list = build_feed_var_list(self.train_program, feed_order)
+            feeder = data_feeder.DataFeeder(
+                feed_list=feed_var_list, place=self.place)
+            reader = feeder.decorate_reader(reader, multi_devices=True)
+            self._train_by_any_executor(event_handler, pe, num_epochs, reader)
+
+    def _get_parallel_executor(self):
+        return getattr(self, 'parallel_executor', None)
+
+    def _get_or_create_parallel_executor(self):
+        if self._get_parallel_executor() is None:
+            self.parallel_executor = parallel_executor.ParallelExecutor(
+                use_cuda=isinstance(self.place, core.CUDAPlace),
+                loss_name=self.train_func_outputs[0].name)
+        return self._get_parallel_executor()
+

 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
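
Putting the pieces together, an end-to-end sketch of the new surface, assuming Trainer and the event classes are re-exported under paddle.fluid and reusing the illustrative train_func, event_handler, and train_reader from the sketches above:

    import paddle.fluid as fluid

    trainer = fluid.Trainer(
        train_func=train_func,
        optimizer=fluid.optimizer.SGD(learning_rate=0.001),
        place=fluid.CUDAPlace(0),
        parallel=True)  # route training through ParallelExecutor

    trainer.train(
        num_epochs=10,
        event_handler=event_handler,
        reader=train_reader,
        feed_order=['x', 'y'])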