@@ -151,11 +151,62 @@ def check_and_get_place(place):
151
151
152
152
class Trainer (object ):
153
153
"""
154
+ A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
155
+ simple neural network easily.
156
+
157
+ This API takes a :code:`train_func`. A :code:`train_func` is a function that
158
+ returns loss as its first return value. The rest of the values can be fetched by
159
+ EndStepEvent.metrics
160
+
161
+ This API also takes a :code:`optimizer_func` that will return an optimizer
162
+ instance.
163
+
164
+ For example, to train a MLP for MNIST dataset, the sample program is
165
+
166
+ >>> import paddle.fluid as fluid
167
+ >>>
168
+ >>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
169
+ >>> hidden = image
170
+ >>> for layer_size in layer_sizes:
171
+ >>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
172
+ >>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
173
+ >>>
174
+ >>> def train_mnist_mlp():
175
+ >>> img = fluid.layers.data(name='image', shape=[784])
176
+ >>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
177
+ >>> prediction = mlp(img)
178
+ >>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
179
+ >>>
180
+ >>> def optimizer():
181
+ >>> return fluid.optimizer.Adam()
182
+ >>>
183
+ >>> trainer = Trainer(train_func=train_mnist_mlp,
184
+ >>> optimizer_func=optimizer,
185
+ >>> place=fluid.CUDAPlace(0),
186
+ >>> parallel=True)
187
+ >>>
188
+ >>> def train_callback(event):
189
+ >>> if isinstance(event, fluid.EndStepEvent):
190
+ >>> print "Epoch ID", event.epoch, "Step ID",\
191
+ >>> event.step, "AvgLoss", event.metrics[0]
192
+ >>> elif isinstance(event, fluid.EndEpochEvent):
193
+ >>> trainer.save_params("./model_{0}".format(event.epoch))
194
+ >>>
195
+ >>> trainer.train(num_epochs=100, event_handler=train_callback)
196
+
197
+ For more examples, please see :ref:`api_guide_high_level_api`.
198
+
154
199
155
200
Args:
156
- train_func(callable): A function which will return loss. The loss must be a scalar.
201
+ train_func(callable): A function which will return loss. The loss must be
202
+ a scalar tensor.
157
203
optimizer_func(callable): A function that returns an Optimizer object.
158
- place: The device place of this trainer.
204
+ place(CUDAPlace|CPUPlace): The device place of this trainer. If
205
+ :code:`parallel=True`, all CUDA Places will be used if :code:`place`
206
+ is a :code:`CUDAPlace`.
207
+ parallel(bool): True if using multiple devices.
208
+ checkpoint_config(CheckpointConfig): Configuration about how to save
209
+ checkpoints.
159
210
"""
160
211
161
212
def __init__ (self ,
@@ -167,9 +218,6 @@ def __init__(self,
167
218
checkpoint_config = None ):
168
219
self .__stop = False
169
220
self .parallel = parallel
170
- # 1. we need to generate a framework.Program by calling
171
- # program_func. Reference: fluid.program_guard in
172
- # test_word2vec.py
173
221
174
222
# config for checkpoint
175
223
# only chief worker will save variables
@@ -183,6 +231,10 @@ def __init__(self,
183
231
184
232
self .scope = core .Scope ()
185
233
234
+ # 1. we need to generate a framework.Program by calling
235
+ # program_func. Reference: fluid.program_guard in
236
+ # test_word2vec.py
237
+
186
238
self .startup_program = framework .Program ()
187
239
self .train_program = framework .Program ()
188
240
@@ -315,17 +367,18 @@ def stop(self):
315
367
316
368
def train (self , num_epochs , event_handler , reader = None , feed_order = None ):
317
369
"""
318
- Train the model.
370
+ Start the train loop to train the model.
319
371
320
372
Args:
321
373
num_epochs: The number of epoch. An epoch will process all data in reader
322
374
event_handler: The event handler. A function with type (ev:Event)->void
323
- reader:
375
+ reader: A reader creator object. See also
376
+ :ref:`api_guide_python_reader` .
324
377
feed_order: Feeding order of reader. None will following the defining
325
378
order in program
326
379
327
380
Returns:
328
-
381
+ None
329
382
"""
330
383
training_role = os .getenv ("PADDLE_TRAINING_ROLE" , "" )
331
384
if training_role == "PSERVER" :
@@ -354,7 +407,14 @@ def test(self, reader, feed_order):
354
407
self .train_func_outputs )
355
408
356
409
def save_params (self , param_path ):
357
- # reference: save_persistables in io.py
410
+ """
411
+ Save all parameters into :code:`param_path`
412
+ Args:
413
+ param_path(str): The path to save parameters
414
+
415
+ Returns:
416
+ None
417
+ """
358
418
with self ._prog_and_scope_guard ():
359
419
exe = executor .Executor (self .place )
360
420
io .save_persistables (exe , dirname = param_path )
0 commit comments