using two ``ml.p4d.24xlarge`` instances:

    pt_estimator.fit("s3://bucket/path/to/training/data")

.. _distributed-pytorch-training-on-trainium:

Distributed Training with PyTorch Neuron on Trn1 instances
==========================================================

SageMaker Training supports Amazon EC2 Trn1 instances powered by
the `AWS Trainium <https://aws.amazon.com/machine-learning/trainium/>`_ device,
the second-generation purpose-built machine learning accelerator from AWS.
Each Trn1 instance consists of up to 16 Trainium devices, and each
Trainium device consists of two `NeuronCores
<https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trn1-arch.html#trainium-architecture>`_
as described in the *AWS Neuron Documentation*.

You can run distributed training jobs on Trn1 instances.
SageMaker supports the ``xla`` backend through ``torchrun``,
so you do not need to manually pass ``RANK``,
``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``.
You can launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``torch_distributed`` option as the distribution strategy.
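
For illustration, the following minimal sketch shows how an entry script can
read the process topology that ``torchrun`` exports. The environment variable
names are standard ``torchrun`` conventions rather than SageMaker-specific
settings, and the print statement is only an example.

.. code:: python

    import os

    # torchrun, launched for you by SageMaker, exports these variables
    # for every worker process, so the entry script only reads them.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    print(f"worker {rank} of {world_size} (local rank {local_rank}) started")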

.. note::

    This ``torch_distributed`` support is available
    in the AWS Deep Learning Containers for PyTorch Neuron starting with v1.11.0.
    To find a complete list of supported versions of PyTorch Neuron, see
    `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_
    in the *AWS Deep Learning Containers GitHub repository*.

.. note::

    SageMaker Debugger is currently not supported with Trn1 instances.
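
If you want to make sure no Debugger hook is attached to a Trn1 training job,
you can pass ``debugger_hook_config=False`` to the estimator, as in the sketch
below; this is an optional precaution, not a documented requirement for Trn1.

.. code:: python

    from sagemaker.pytorch import PyTorch

    # Same torch_distributed setup as in the examples further below,
    # with the Debugger hook explicitly disabled (optional precaution).
    pt_estimator = PyTorch(
        entry_point="train_torch_distributed.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        debugger_hook_config=False,
        distribution={"torch_distributed": {"enabled": True}},
    )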

Adapt Your Training Script to Initialize with the XLA backend
-------------------------------------------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the ``xla`` backend as shown below.

.. code:: python

    import torch.distributed as dist

    # torchrun supplies the rank, world size, and rendezvous address and port,
    # so only the backend needs to be specified here.
    dist.init_process_group('xla')

SageMaker takes care of ``MASTER_ADDR`` and ``MASTER_PORT`` for you via ``torchrun``.

For detailed documentation about modifying your training script for Trainium, see `Multi-worker data-parallel MLP training using torchrun <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/mlp.html?highlight=torchrun#multi-worker-data-parallel-mlp-training-using-torchrun>`_ in the *AWS Neuron Documentation*.
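
The tutorial above walks through a complete example. As a rough outline only,
an adapted entry script typically follows the pattern sketched below; the toy
model, dataset, and hyperparameters are illustrative stand-ins, and the exact
``torch_xla`` calls can vary between PyTorch Neuron versions.

.. code:: python

    import torch
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend  # registers the 'xla' backend (needed in some torch-xla versions)


    def train():
        dist.init_process_group('xla')
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        device = xm.xla_device()  # the NeuronCore assigned to this worker process

        # A toy model and random dataset stand in for your own training code.
        model = torch.nn.Linear(32, 10).to(device)
        dataset = torch.utils.data.TensorDataset(
            torch.randn(1024, 32), torch.randint(0, 10, (1024,))
        )
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank
        )
        loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=32)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.CrossEntropyLoss()

        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            xm.optimizer_step(optimizer)  # all-reduce gradients across workers, then step
            xm.mark_step()                # execute the accumulated XLA graph


    if __name__ == "__main__":
        train()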

**Currently supported backends:**

- ``xla`` for Trainium (Trn1) instances

For up-to-date information on supported backends for Trn1 instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html>`_.

Launching a Distributed Training Job on Trainium
------------------------------------------------

You can run multi-node distributed PyTorch training jobs on Trn1 instances using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch Neuron, sets up the environment, and launches
the training job using the ``torchrun`` command on each worker node.

**Examples**

The following examples show how to run a PyTorch training job using ``torch_distributed`` in SageMaker
on one ``ml.trn1.2xlarge`` instance and two ``ml.trn1.32xlarge`` instances:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_torch_distributed.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=1,
        instance_type="ml.trn1.2xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_torch_distributed.py",
        role="SageMakerRole",
        framework_version="1.11.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.trn1.32xlarge",
        distribution={
            "torch_distributed": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

*********************
Deploy PyTorch Models
*********************