@@ -196,6 +196,7 @@ fit Optional Arguments
 - ``logs``: Defaults to True, whether to show logs produced by training
   job in the Python session. Only meaningful when wait is True.

+----

 Distributed PyTorch Training
 ============================
@@ -262,16 +263,18 @@ during the PyTorch DDP initialization.

 .. note::

-  The SageMaker PyTorch estimator operates ``mpirun`` in the backend.
-  It doesn’t use ``torchrun`` for distributed training.
+  The SageMaker PyTorch estimator can use either ``mpirun`` (for PyTorch 1.12.0 and later)
+  or ``torchrun`` (for PyTorch 1.13.1 and later) as the backend for distributed training.

 For more information about setting up PyTorch DDP in your training script,
 see `Getting Started with Distributed Data Parallel
 <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
 PyTorch documentation.

-The following example shows how to run a PyTorch DDP training in SageMaker
-using two ``ml.p4d.24xlarge`` instances:
+The following examples show how to configure a PyTorch estimator
+to run a distributed training job on two ``ml.p4d.24xlarge`` instances.
+
+**Using PyTorch DDP with the mpirun backend**

 .. code:: python

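The body of the mpirun example above falls outside the hunks shown here. Purely as a
sketch, an mpirun-backed DDP configuration might look like the following; the
``pytorchddp`` distribution key and the framework version are assumptions, not taken
from this diff:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        # Assumed key for the mpirun-backed DDP launcher; the exact option used by
        # the surrounding document is not visible in this diff.
        distribution={
            "pytorchddp": {
                "enabled": True
            }
        }
    )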
@@ -291,7 +294,34 @@ using two ``ml.p4d.24xlarge`` instances:
         }
     )

-    pt_estimator.fit("s3://bucket/path/to/training/data")
+**Using PyTorch DDP with the torchrun backend**
+
+.. code:: python
+
+    from sagemaker.pytorch import PyTorch
+
+    pt_estimator = PyTorch(
+        entry_point="train_ptddp.py",
+        role="SageMakerRole",
+        framework_version="1.13.1",
+        py_version="py38",
+        instance_count=2,
+        instance_type="ml.p4d.24xlarge",
+        distribution={
+            "torch_distributed": {
+                "enabled": True
+            }
+        }
+    )
+
+
+.. note::
+
+  For more information about setting up ``torchrun`` in your training script,
+  see `torchrun (Elastic Launch) <https://pytorch.org/docs/stable/elastic/run.html>`_ in *the
+  PyTorch documentation*.
+
+----

 .. _distributed-pytorch-training-on-trainium:

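The note above covers the launcher side only. As a minimal, hypothetical sketch of the
training-script side (the script name ``train_ptddp.py`` and the ``nccl`` backend are
assumptions here), the process group can be initialized from the environment variables
that ``torchrun`` exports:

.. code:: python

    import os

    import torch
    import torch.distributed as dist


    def main():
        # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT,
        # so the default env:// initialization needs no extra arguments.
        dist.init_process_group(backend="nccl")

        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        model = torch.nn.Linear(10, 1).cuda(local_rank)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )

        # ... run the training loop with ddp_model ...

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()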
@@ -324,7 +354,7 @@ with the ``torch_distributed`` option as the distribution strategy.

 .. note::

-  SageMaker Debugger is currently not supported with Trn1 instances.
+  SageMaker Debugger is not compatible with Trn1 instances.

 Adapt Your Training Script to Initialize with the XLA backend
 -------------------------------------------------------------
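The body of the section introduced by this heading lies outside the hunks shown here.
As a rough sketch only, XLA-based initialization for Trn1 typically registers the
``xla`` backend by importing ``torch_xla.distributed.xla_backend`` and then calls
``init_process_group``; treat the surrounding document, not this sketch, as the
authoritative version:

.. code:: python

    import torch.distributed as dist

    # Importing this module registers the "xla" backend with torch.distributed.
    import torch_xla.distributed.xla_backend  # noqa: F401

    # torchrun (launched via the torch_distributed option) provides RANK and
    # WORLD_SIZE, so the default env:// initialization works here as well.
    dist.init_process_group("xla")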