Skip to content

Commit 958ab99

Browse files
committed
Polish Non-Layer API
1 parent 16a0f74 commit 958ab99

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

python/paddle/fluid/trainer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ def __init__(self, epoch_id):
3838

3939

4040
class EndEpochEvent(object):
41+
"""
42+
The end of a training epoch.
43+
44+
Args:
45+
epoch_id(int): The current epoch ID.
46+
"""
47+
4148
def __init__(self, epoch_id):
4249
self.epoch = epoch_id
4350

@@ -50,13 +57,44 @@ def __init__(self, epoch_id, step_id):
5057

5158

5259
class EndStepEvent(object):
60+
"""
61+
The end of a training step.
62+
63+
Args:
64+
epoch_id(int): The current epoch ID.
65+
step_id(int): The current step ID.
66+
metrics(list): A list of fetched tensor. The order of this list is same
67+
as the :code:`train_func` returns.
68+
"""
69+
5370
def __init__(self, epoch_id, step_id, metrics):
5471
self.epoch = epoch_id
5572
self.step = step_id
5673
self.metrics = metrics
5774

5875

5976
class CheckpointConfig(object):
77+
"""
78+
Parameter object for :code:`fluid.io.save_checkpoint` and
79+
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
80+
81+
Args:
82+
checkpoint_dir(str): Directory path to save check point. Default is the
83+
current directory.
84+
85+
max_num_checkpoints(int): The max number of local check points.
86+
epoch_interval(int): Every number of epoch to save check point.
87+
step_interval(int): Every number of step to save check point.
88+
89+
Examples:
90+
>>> config = fluid.CheckpointConfig("./checkpoints")
91+
>>> trainer = fluid.Trainer(train_func=train_program,
92+
>>> place=place,
93+
>>> optimizer_func=optimizer_func,
94+
>>> checkpoint_config=config)
95+
>>> trainer.train(...)
96+
"""
97+
6098
def __init__(self,
6199
checkpoint_dir=None,
62100
max_num_checkpoints=3,

0 commit comments

Comments
 (0)