Skip to content

Commit f97c5d4

Browse files
committed
Trainer documentation
1 parent 08995ac commit f97c5d4

File tree

1 file changed

+69
-9
lines changed

1 file changed

+69
-9
lines changed

python/paddle/fluid/trainer.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,62 @@ def check_and_get_place(place):
151151

152152
class Trainer(object):
153153
"""
154+
A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
155+
simple neural network easily.
156+
157+
This API takes a :code:`train_func`. A :code:`train_func` is a function that
158+
return loss as its first return value. The rest of the values can be fetched by
159+
EndStepEvent.metrics
160+
161+
This API also takes an :code:`optimizer_func` that will return an optimizer
162+
instance.
163+
164+
For example, to train an MLP on the MNIST dataset, the sample program is
165+
166+
>>> import paddle.fluid as fluid
167+
>>>
168+
>>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
169+
>>> hidden = image
170+
>>> for layer_size in layer_sizes:
171+
>>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
172+
>>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
173+
>>>
174+
>>> def train_mnist_mlp():
175+
>>> img = fluid.layers.data(name='image', shape=[784])
176+
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
177+
>>> prediction = mlp(img)
178+
>>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
179+
>>>
180+
>>> def optimizer():
181+
>>> return fluid.optimizer.Adam()
182+
>>>
183+
>>> trainer = Trainer(train_func=train_mnist_mlp,
184+
>>> optimizer_func=optimizer,
185+
>>> place=fluid.CUDAPlace(0),
186+
>>> parallel=True)
187+
>>>
188+
>>> def train_callback(event):
189+
>>> if isinstance(event, fluid.EndStepEvent):
190+
>>> print "Epoch ID", event.epoch, "Step ID",\
191+
>>> event.step, "AvgLoss", event.metrics[0]
192+
>>> elif isinstance(event, fluid.EndEpochEvent):
193+
>>> trainer.save_params("./model_{0}".format(event.epoch))
194+
>>>
195+
>>> trainer.train(num_epochs=100, event_handler=train_callback)
196+
197+
For more examples, please see :ref:`api_guide_high_level_api`.
198+
154199
155200
Args:
156-
train_func(callable): A function which will return loss. The loss must be a scalar.
201+
train_func(callable): A function which will return loss. The loss must be
202+
a scalar tensor.
157203
optimizer_func(callable): A function that returns an Optimizer object.
158-
place: The device place of this trainer.
204+
place(CUDAPlace|CPUPlace): The device place of this trainer. If
205+
:code:`parallel=True,` all CUDA Places will be used if :code:`place`
206+
is a :code:`CUDAPlace`.
207+
parallel(bool): True if use multiple devices.
208+
checkpoint_config(CheckpointConfig): Configuration about how to save
209+
checkpoints.
159210
"""
160211

161212
def __init__(self,
@@ -167,9 +218,6 @@ def __init__(self,
167218
checkpoint_config=None):
168219
self.__stop = False
169220
self.parallel = parallel
170-
# 1. we need to generate a framework.Program by calling
171-
# program_func. Reference: fluid.program_guard in
172-
# test_word2vec.py
173221

174222
# config for checkpoint
175223
# only chief worker will save variables
@@ -183,6 +231,10 @@ def __init__(self,
183231

184232
self.scope = core.Scope()
185233

234+
# 1. we need to generate a framework.Program by calling
235+
# program_func. Reference: fluid.program_guard in
236+
# test_word2vec.py
237+
186238
self.startup_program = framework.Program()
187239
self.train_program = framework.Program()
188240

@@ -315,17 +367,18 @@ def stop(self):
315367

316368
def train(self, num_epochs, event_handler, reader=None, feed_order=None):
317369
"""
318-
Train the model.
370+
Start the train loop to train the model.
319371
320372
Args:
321373
num_epochs: The number of epoch. An epoch will process all data in reader
322374
event_handler: The event handler. A function with type (ev:Event)->void
323-
reader:
375+
reader: A reader creator object. See also
376+
:ref:`api_guide_python_reader` .
324377
feed_order: Feeding order of reader. None will follow the defining
325378
order in program
326379
327380
Returns:
328-
381+
None
329382
"""
330383
training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
331384
if training_role == "PSERVER":
@@ -354,7 +407,14 @@ def test(self, reader, feed_order):
354407
self.train_func_outputs)
355408

356409
def save_params(self, param_path):
357-
# reference: save_persistables in io.py
410+
"""
411+
Save all parameters into :code:`param_path`
412+
Args:
413+
param_path(str): The path to save parameters
414+
415+
Returns:
416+
None
417+
"""
358418
with self._prog_and_scope_guard():
359419
exe = executor.Executor(self.place)
360420
io.save_persistables(exe, dirname=param_path)

0 commit comments

Comments
 (0)