@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
 * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
   of partial checkpoints to keep on disk.

-.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None)
+.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state=True, translate_function=None)

 While :class:`smdistributed.modelparallel.torch.load` loads saved
 model and optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
 * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint.
 * ``strict`` (boolean) (default: True): Whether to load strictly; no extra or
   missing keys are allowed.
-* ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``.
+* ``load_optimizer`` (boolean) (default: True): Whether to load the optimizer.
+* ``load_sharded_optimizer_state`` (boolean) (default: True): Whether to load
+  the sharded optimizer state of a model.
+  It can be used only when you activate
+  the `sharded data parallelism
+  <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html>`_
+  feature of the SageMaker model parallel library.
+  When this is ``False``, the library loads only the states used for FP16
+  training, such as the FP32 master parameters and the loss scaling factor,
+  not the sharded optimizer state.
 * ``translate_function`` (function) (default: None): A function to translate the full
   checkpoint into the smdistributed.modelparallel format.
   For supported models, this is not required.