
Commit 4e86c89

Merge pull request #10620 from reyoung/feature/trainer_by_pe
Draft for train by parallel executor
2 parents dbbeccc + f047548 commit 4e86c89

File tree

2 files changed, +71 -28 lines

python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py

Lines changed: 8 additions & 2 deletions
@@ -62,7 +62,10 @@ def train(use_cuda, train_program, save_dirname):
     optimizer = fluid.optimizer.Adam(learning_rate=0.001)

     trainer = fluid.Trainer(
-        train_func=train_program, place=place, optimizer=optimizer)
+        train_func=train_program,
+        place=place,
+        optimizer=optimizer,
+        parallel=True)

     def event_handler(event):
         if isinstance(event, fluid.EndEpochEvent):
@@ -87,6 +90,9 @@ def event_handler(event):
                     event.epoch + 1, float(avg_cost), float(acc)))
                 if math.isnan(float(avg_cost)):
                     sys.exit("got NaN loss, training failed.")
+        elif isinstance(event, fluid.EndStepEvent):
+            print("Step {0}, Epoch {1} Metrics {2}".format(
+                event.step, event.epoch, map(numpy.array, event.metrics)))

     train_reader = paddle.batch(
         paddle.reader.shuffle(
@@ -131,4 +137,4 @@ def main(use_cuda):

 if __name__ == '__main__':
     # for use_cuda in (False, True):
-    main(use_cuda=False)
+    main(use_cuda=True)
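
Taken together with the trainer.py changes below, the test now drives training through the new parallel path. A minimal sketch of that usage (the small fully connected train_program stands in for the test's convolutional network; all other names follow the diff above):

import numpy
import paddle
import paddle.fluid as fluid

def train_program():
    # Stand-in network: the real test builds a conv net; the contract is the
    # same -- return a list whose first element is the scalar loss.
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    predict = fluid.layers.fc(input=img, size=10, act='softmax')
    avg_cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=predict, label=label))
    acc = fluid.layers.accuracy(input=predict, label=label)
    return [avg_cost, acc]

def event_handler(event):
    if isinstance(event, fluid.EndStepEvent):
        # event.metrics mirrors the list returned by train_program
        print("Step {0}, Epoch {1} Metrics {2}".format(
            event.step, event.epoch, map(numpy.array, event.metrics)))

trainer = fluid.Trainer(
    train_func=train_program,
    place=fluid.CUDAPlace(0),
    optimizer=fluid.optimizer.Adam(learning_rate=0.001),
    parallel=True)  # new flag: dispatch to the ParallelExecutor path

trainer.train(
    num_epochs=1,
    event_handler=event_handler,
    reader=paddle.batch(
        paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
        batch_size=64),
    feed_order=['img', 'label'])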

python/paddle/fluid/trainer.py

Lines changed: 63 additions & 26 deletions
@@ -20,6 +20,7 @@
 import contextlib
 import io
 import unique_name
+import parallel_executor

 # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
 import optimizer as opt_module
@@ -48,12 +49,14 @@ class BeginStepEvent(object):
     def __init__(self, epoch_id, step_id):
         self.epoch = epoch_id
         self.step = step_id
+        self.fetch_metrics = True


 class EndStepEvent(object):
-    def __init__(self, epoch_id, step_id):
+    def __init__(self, epoch_id, step_id, metrics):
         self.epoch = epoch_id
         self.step = step_id
+        self.metrics = metrics


 def check_and_get_place(place):
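
Because the refactored train loop (see _train_by_any_executor below) consults begin_event.fetch_metrics before each step, an event handler can switch per-step metric fetching off to avoid the fetch overhead. A hedged sketch, assuming the event classes are exported at the fluid namespace as in the test above; the 100-step interval is purely illustrative:

import paddle.fluid as fluid

def event_handler(event):
    if isinstance(event, fluid.BeginStepEvent):
        # Only fetch the train_func outputs every 100 steps; on other steps
        # the trainer runs with an empty fetch_list.
        event.fetch_metrics = (event.step % 100 == 0)
    elif isinstance(event, fluid.EndStepEvent):
        if event.metrics:  # may be empty when fetch_metrics was False
            print("step %d metrics: %s" % (event.step, event.metrics))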
@@ -87,12 +90,17 @@ class Trainer(object):

     Args:
         train_func(callable): A function which will return loss. The loss must be a scalar.
-        infer_func(callable): A function which will return predict, used to save inference model
         optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
         place: The device place of this trainer.
     """

-    def __init__(self, train_func, optimizer, param_path=None, place=None):
+    def __init__(self,
+                 train_func,
+                 optimizer,
+                 param_path=None,
+                 place=None,
+                 parallel=False):
+        self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
         # program_func. Reference: fluid.program_guard in
         # test_word2vec.py
@@ -106,14 +114,14 @@ def __init__(self, train_func, optimizer, param_path=None, place=None):

         with framework.program_guard(self.train_program, self.startup_program):
             program_func_outs = train_func()
-            self.test_outputs = program_func_outs if isinstance(
+            self.train_func_outputs = program_func_outs if isinstance(
                 program_func_outs, list) else [program_func_outs]
             self.test_program = self.train_program.clone()
             if not isinstance(optimizer, opt_module.Optimizer):
                 raise TypeError(
                     "The optimizer should be an instance of Optimizer")
             # The fisrt element of program_func_outs is loss.
-            loss = self.test_outputs[0]
+            loss = self.train_func_outputs[0]
             optimize_ops, params_grads = optimizer.minimize(loss)

         self.place = check_and_get_place(place)
@@ -202,38 +210,32 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
                 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
             )

-    def train(self,
-              num_epochs,
-              event_handler,
-              reader,
-              feed_order,
-              parallel=False):
+    def train(self, num_epochs, event_handler, reader=None, feed_order=None):
         """
         Train the model.

         Args:
             num_epochs: The number of epoch. An epoch will process all data in reader
             event_handler: The event handler. A function with type (ev:Event)->void
             reader:
-            parallel: True if use multi-CPUs or multi-GPUs
             feed_order: Feeding order of reader. None will following the defining
                 order in program

         Returns:

         """
-        if parallel:
-            raise NotImplementedError(
-                "Parallel Executor version of trainer is not implemented")
-
         training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
         if training_role == "PSERVER":
             with self._prog_and_scope_guard():
                 exe = executor.Executor(self.place)
                 exe.run()
                 return
-
-        self._train_by_executor(num_epochs, event_handler, reader, feed_order)
+        if self.parallel:
+            self._train_by_parallel_executor(num_epochs, event_handler, reader,
+                                             feed_order)
+        else:
+            self._train_by_executor(num_epochs, event_handler, reader,
+                                    feed_order)

     def test(self, reader, feed_order):
         """
@@ -245,7 +247,8 @@ def test(self, reader, feed_order):
                 order in program
         """

-        return self._test_by_executor(reader, feed_order, self.test_outputs)
+        return self._test_by_executor(reader, feed_order,
+                                      self.train_func_outputs)

     def save_params(self, param_path):
         # reference: save_persistables in io.py
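
test() keeps its signature; only the fetched outputs are renamed internally. Continuing the usage sketch after the first file's diff (trainer and feed_order names follow the test file), evaluation still looks like:

avg_cost, acc = trainer.test(
    reader=paddle.batch(paddle.dataset.mnist.test(), batch_size=64),
    feed_order=['img', 'label'])
print("test avg_cost=%f, acc=%f" % (float(avg_cost), float(acc)))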
@@ -279,13 +282,25 @@ def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
             feeder = data_feeder.DataFeeder(
                 feed_list=feed_var_list, place=self.place)
             exe = executor.Executor(self.place)
-            for epoch_id in range(num_epochs):
-                event_handler(BeginEpochEvent(epoch_id))
-                for step_id, data in enumerate(reader()):
-                    event_handler(BeginStepEvent(epoch_id, step_id))
-                    exe.run(feed=feeder.feed(data), fetch_list=[])
-                    event_handler(EndStepEvent(epoch_id, step_id))
-                event_handler(EndEpochEvent(epoch_id))
+            reader = feeder.decorate_reader(reader, multi_devices=False)
+            self._train_by_any_executor(event_handler, exe, num_epochs, reader)
+
+    def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
+        for epoch_id in range(num_epochs):
+            event_handler(BeginEpochEvent(epoch_id))
+            for step_id, data in enumerate(reader()):
+                begin_event = BeginStepEvent(epoch_id, step_id)
+                event_handler(begin_event)
+                if begin_event.fetch_metrics:
+                    metrics = exe.run(feed=data,
+                                      fetch_list=[
+                                          var.name
+                                          for var in self.train_func_outputs
+                                      ])
+                else:
+                    metrics = exe.run(feed=data, fetch_list=[])
+                event_handler(EndStepEvent(epoch_id, step_id, metrics))
+            event_handler(EndEpochEvent(epoch_id))

     def _test_by_executor(self, reader, feed_order, fetch_list):
         with executor.scope_guard(self.scope):
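
Note that the per-step loop no longer calls feeder.feed(data) itself: decorate_reader wraps the reader so it already yields ready-made feed data (a feed dict per step in the single-device case; per-device splits when multi_devices=True), which is what exe.run(feed=data) consumes in both branches. A rough, self-contained sketch of the single-device case (the toy network and reader are illustrative, not part of this commit):

import numpy
import paddle
import paddle.fluid as fluid

# Toy program: one FC layer over a 4-d input.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

feeder = fluid.DataFeeder(feed_list=[x], place=place)

def raw_reader():
    for _ in range(4):
        yield (numpy.random.random(4).astype('float32'), )

decorated = feeder.decorate_reader(
    paddle.batch(raw_reader, batch_size=2), multi_devices=False)

for data in decorated():
    # data is already a {var_name: tensor} feed dict, so it can be passed
    # straight to Executor.run, as _train_by_any_executor does.
    loss_val, = exe.run(feed=data, fetch_list=[loss.name])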
@@ -304,6 +319,28 @@ def _test_by_executor(self, reader, feed_order, fetch_list):

         return [x / count for x in accumulated]

+    def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
+                                    feed_order):
+        with self._prog_and_scope_guard():
+            pe = self._get_or_create_parallel_executor()
+            feed_var_list = build_feed_var_list(self.train_program, feed_order)
+            feeder = data_feeder.DataFeeder(
+                feed_list=feed_var_list, place=self.place)
+            reader = feeder.decorate_reader(reader, multi_devices=True)
+            for epoch_id in range(num_epochs):
+                self._train_by_any_executor(event_handler, pe, num_epochs,
+                                            reader)
+
+    def _get_parallel_executor(self):
+        return getattr(self, 'parallel_executor', None)
+
+    def _get_or_create_parallel_executor(self):
+        if self._get_parallel_executor() is None:
+            self.parallel_executor = parallel_executor.ParallelExecutor(
+                use_cuda=isinstance(self.place, core.CUDAPlace),
+                loss_name=self.train_func_outputs[0].name)
+        return self._get_parallel_executor()
+

 def build_feed_var_list(program, feed_order):
     if not isinstance(program, framework.Program):
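
The trainer builds a single ParallelExecutor lazily, on first use, and caches it on the trainer (self.parallel_executor) so every epoch reuses the same instance; use_cuda is derived from the trainer's place and loss_name from the first train_func output. A roughly equivalent standalone construction, as a sketch (the tiny stand-in program below is illustrative and the startup program must have been run first):

import paddle.fluid as fluid

# Tiny stand-in program; avg_cost plays the role of train_func_outputs[0].
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
avg_cost = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_cost)

place = fluid.CUDAPlace(0)
fluid.Executor(place).run(fluid.default_startup_program())

# What _get_or_create_parallel_executor builds on first use and then caches:
pe = fluid.ParallelExecutor(
    use_cuda=isinstance(place, fluid.core.CUDAPlace),
    loss_name=avg_cost.name)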
