doc/using_mxnet.rst
You don't have to use all the arguments; arguments you don't care about can be ignored.

**Note: Writing a training script that imports correctly:**
When SageMaker runs your training script, it imports it as a Python module and then invokes ``train`` on the imported module. Consequently, you should not include any statements that won't execute successfully in SageMaker when your module is imported. For example, don't attempt to open any local files in top-level statements in your training script.
If you want to run your training script locally by using the Python interpreter, use an ``if __name__ == '__main__':`` guard.

For more information, see https://stackoverflow.com/questions/419163/what-does-if-name-main-do.
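As a minimal sketch of such a guard (the ``train`` function body and its return value are placeholders, not SDK code):

```python
# train.py -- hypothetical minimal training script layout.
# SageMaker imports this module and calls train() itself, so nothing
# outside the guard below should run side effects at import time.

def train(num_epochs=1):
    """Placeholder training logic; returns the trained model object."""
    return 'trained-model'

if __name__ == '__main__':
    # Executed only when the script is run directly (python train.py),
    # not when the module is imported by SageMaker.
    print(train())
```

Because all work happens inside ``train`` and the guard, importing this module never opens files or starts training on its own.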
Save the Model
--------------
After your ``train`` function completes, SageMaker invokes ``save`` with the return value of ``train``.
If your train function returns a Gluon API ``net`` object as its model, you'll need to write your own ``save`` function. You will want to serialize the ``net`` parameters. Saving ``net`` parameters is covered in the `Serialization section <http://gluon.mxnet.io/chapter03_deep-neural-networks/serialization.html>`__ of the collaborative Gluon deep-learning book `"The Straight Dope" <http://gluon.mxnet.io/index.html>`__.
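As an illustrative sketch, such a ``save`` function can call the Gluon ``save_parameters`` method; the ``save(net, model_dir)`` signature and the ``model.params`` file name are assumptions for this example:

```python
import os

def save(net, model_dir):
    # Hypothetical save hook: serialize only the network's parameters
    # into the directory that SageMaker uploads after training.
    # The 'model.params' file name is an assumption for this sketch.
    net.save_parameters(os.path.join(model_dir, 'model.params'))
```

At load time, a matching ``net.load_parameters`` call on an identically constructed network restores the weights.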
Save a Checkpoint
-----------------
It is good practice to save the best model after each training epoch, so that you can resume a training job if it gets interrupted. This is particularly important if you are using Managed Spot training.
To save MXNet model checkpoints, do the following in your training script:
* Set the ``CHECKPOINTS_DIR`` environment variable and enable checkpoints.
* Make sure you are emitting a validation metric to test the model. For information, see `Evaluation Metric API <https://mxnet.incubator.apache.org/api/python/metric/metric.html>`_.
* After each training epoch, test whether the current model performs the best with respect to the validation metric, and if it does, save that model to ``CHECKPOINTS_DIR``.
.. code:: python

    if checkpoints_enabled and current_host == hosts[0]:
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            logging.info('Saving the model, params and optimizer state')
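The snippet above can be fleshed out as a self-contained sketch. The helper name, the default host values, and the plain-file write standing in for the MXNet-specific save calls are all assumptions for illustration:

```python
import json
import logging
import os

def maybe_checkpoint(val_acc, best_accuracy, checkpoints_dir,
                     checkpoints_enabled=True,
                     current_host='algo-1', hosts=('algo-1',)):
    """Hypothetical helper: checkpoint only on the leader host, and only
    when the validation metric improves on the best value seen so far."""
    if checkpoints_enabled and current_host == hosts[0]:
        if val_acc > best_accuracy:
            logging.info('Saving the model, params and optimizer state')
            # Stand-in for the real MXNet save calls; records the new
            # best metric so an interrupted job can resume from it.
            with open(os.path.join(checkpoints_dir, 'best.json'), 'w') as f:
                json.dump({'val_acc': val_acc}, f)
            return val_acc  # new best accuracy
    return best_accuracy
```

Guarding on ``hosts[0]`` means that in distributed training only one host writes checkpoints, avoiding concurrent writes to the same location.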
For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.