Skip to content

Commit eb7d875

Browse files
jacquesqiao and daming-lu
authored and committed
add trainer.stop and fix a bug for train_by_parallel_executor (#10762)
1 parent 54ae8e4 commit eb7d875

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,20 @@ def train(use_cuda, train_program, save_dirname):
5757
optimizer=fluid.optimizer.SGD(learning_rate=0.001))
5858

5959
def event_handler(event):
60-
if isinstance(event, fluid.EndEpochEvent):
61-
test_metrics = trainer.test(
62-
reader=test_reader, feed_order=['x', 'y'])
63-
print test_metrics
64-
'''
65-
66-
...
67-
['25.768919467926025']
68-
['15.343549569447836']
69-
...
70-
71-
'''
72-
if float(test_metrics[0]) < 20.0:
60+
if isinstance(event, fluid.EndStepEvent):
61+
if event.step == 10:
62+
test_metrics = trainer.test(
63+
reader=test_reader, feed_order=['x', 'y'])
64+
print test_metrics
65+
'''
66+
...
67+
['25.768919467926025']
68+
['15.343549569447836']
69+
...
70+
'''
7371
if save_dirname is not None:
7472
trainer.save_params(save_dirname)
75-
return
73+
trainer.stop()
7674

7775
trainer.train(
7876
reader=train_reader,

python/paddle/fluid/trainer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self,
100100
param_path=None,
101101
place=None,
102102
parallel=False):
103+
self.__stop = False
103104
self.parallel = parallel
104105
# 1. we need to generate a framework.Program by calling
105106
# program_func. Reference: fluid.program_guard in
@@ -210,6 +211,12 @@ def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
210211
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
211212
)
212213

214+
def stop(self):
215+
"""
216+
stop training
217+
"""
218+
self.__stop = True
219+
213220
def train(self, num_epochs, event_handler, reader=None, feed_order=None):
214221
"""
215222
Train the model.
@@ -289,6 +296,8 @@ def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
289296
for epoch_id in range(num_epochs):
290297
event_handler(BeginEpochEvent(epoch_id))
291298
for step_id, data in enumerate(reader()):
299+
if self.__stop:
300+
return
292301
begin_event = BeginStepEvent(epoch_id, step_id)
293302
event_handler(begin_event)
294303
if begin_event.fetch_metrics:
@@ -327,9 +336,7 @@ def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
327336
feeder = data_feeder.DataFeeder(
328337
feed_list=feed_var_list, place=self.place)
329338
reader = feeder.decorate_reader(reader, multi_devices=True)
330-
for epoch_id in range(num_epochs):
331-
self._train_by_any_executor(event_handler, pe, num_epochs,
332-
reader)
339+
self._train_by_any_executor(event_handler, pe, num_epochs, reader)
333340

334341
def _get_parallel_executor(self):
335342
return getattr(self, 'parallel_executor', None)

0 commit comments

Comments
 (0)