
Commit 4dc6a5e

eslesar-aws authored and knakad committed

doc: add checkpoint section to using_mxnet topic (#1008)

Added a section with instructions and example of saving checkpoints during training with MXNet.

1 parent b7db41f commit 4dc6a5e

File tree

1 file changed: +33 -1 lines changed


doc/using_mxnet.rst

Lines changed: 33 additions & 1 deletion
@@ -136,7 +136,8 @@ You don't have to use all the arguments, arguments you don't care about can be i
**Note: Writing a training script that imports correctly:**
When SageMaker runs your training script, it imports it as a Python module and then invokes ``train`` on the imported module. Consequently, you should not include any statements that won't execute successfully in SageMaker when your module is imported. For example, don't attempt to open any local files in top-level statements in your training script.

If you want to run your training script locally by using the Python interpreter, use a ``__name__ == '__main__'`` guard.
For more information, see https://stackoverflow.com/questions/419163/what-does-if-name-main-do.
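A minimal sketch of such a guard (the ``train`` body here is a hypothetical stand-in, not from the SDK):

.. code:: python

    def train():
        # Hypothetical training logic; SageMaker calls this after importing
        # the module, so nothing below the guard runs on import.
        return 'trained'

    if __name__ == '__main__':
        # Executes only when the script is run directly,
        # e.g. ``python train.py``, never when it is imported.
        train()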

Save the Model
--------------
@@ -194,6 +195,37 @@ After your ``train`` function completes, SageMaker will invoke ``save`` with the

If your train function returns a Gluon API ``net`` object as its model, you'll need to write your own ``save`` function. You will want to serialize the ``net`` parameters. Saving ``net`` parameters is covered in the `Serialization section <http://gluon.mxnet.io/chapter03_deep-neural-networks/serialization.html>`__ of the collaborative Gluon deep-learning book `"The Straight Dope" <http://gluon.mxnet.io/index.html>`__.
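For example, a custom ``save`` might look like the following sketch. It assumes SageMaker hands your function the trained ``net`` and a model directory, and it uses Gluon's ``save_parameters`` to serialize the weights; treat the signature and the filename as illustrative, not prescriptive:

.. code:: python

    import os

    def save(net, model_dir):
        # Serialize the Gluon net's weights; ``save_parameters`` writes
        # them to a single file that can be reloaded later.
        net.save_parameters(os.path.join(model_dir, 'model.params'))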

Save a Checkpoint
-----------------

It is good practice to save the best model after each training epoch,
so that you can resume a training job if it gets interrupted.
This is particularly important if you are using Managed Spot training.

To save MXNet model checkpoints, do the following in your training script:

* Define a ``CHECKPOINTS_DIR`` path and enable checkpoints only if that directory exists.

.. code:: python

    import os

    CHECKPOINTS_DIR = '/opt/ml/checkpoints'
    checkpoints_enabled = os.path.exists(CHECKPOINTS_DIR)

* Make sure you are emitting a validation metric to test the model. For more information, see the `Evaluation Metric API <https://mxnet.incubator.apache.org/api/python/metric/metric.html>`_.
* After each training epoch, test whether the current model performs the best with respect to the validation metric, and if it does, save that model to ``CHECKPOINTS_DIR``.

.. code:: python

    if checkpoints_enabled and current_host == hosts[0]:
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            logging.info('Saving the model, params and optimizer state')
            net.export(CHECKPOINTS_DIR + '/%.4f-cifar10' % best_accuracy, epoch)
            trainer.save_states(CHECKPOINTS_DIR + '/%.4f-cifar10-%d.states' % (best_accuracy, epoch))

For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.
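When you resume an interrupted job, you first have to locate the best checkpoint that the snippet above wrote. A helper along these lines could do that; it is a sketch that assumes the ``'%.4f-cifar10-%d.states'`` file naming from the example and is not part of the SDK:

.. code:: python

    import os

    def best_checkpoint(checkpoints_dir):
        # Trainer-state files are named with a '%.4f' validation-accuracy
        # prefix, so the numerically largest prefix is the best model.
        states = [f for f in os.listdir(checkpoints_dir)
                  if f.endswith('.states')]
        if not states:
            return None
        best = max(states, key=lambda f: float(f.split('-')[0]))
        return os.path.join(checkpoints_dir, best)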
Updating your MXNet training script
-----------------------------------
