1717import re
1818from typing import Optional , Union , Dict
1919
20- from sagemaker .deprecations import renamed_kwargs
2120from sagemaker .estimator import Framework , EstimatorBase
2221from sagemaker .fw_utils import (
2322 framework_name_from_image ,
24- warn_if_parameter_server_with_multi_gpu ,
25- validate_smdistributed ,
23+ validate_distribution ,
2624)
2725from sagemaker .huggingface .model import HuggingFaceModel
2826from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
@@ -37,6 +35,9 @@ class HuggingFace(Framework):
3735 """Handle training of custom HuggingFace code."""
3836
3937 _framework_name = "huggingface"
38+ LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
39+ LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
40+ INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4041
4142 def __init__ (
4243 self ,
@@ -142,6 +143,36 @@ def __init__(
142143 }
143144 }
144145
146+ **To enable PyTorch DDP:**
147+
148+ .. code:: python
149+
150+ {
151+ "pytorchddp": {
152+ "enabled": True
153+ }
154+ }
155+
156+ To learn more, see `Distributed PyTorch Training
157+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
158+
159+ **To enable Torch Distributed:**
160+
161+ This is available for general distributed training on
162+ GPU instances from PyTorch v1.13.1 and later.
163+
164+ .. code:: python
165+
166+ {
167+ "torch_distributed": {
168+ "enabled": True
169+ }
170+ }
171+
172+ This option also supports distributed training on Trn1.
173+ To learn more, see `Distributed PyTorch Training on Trainium
174+ <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
175+
145176 To enable distributed training with
146177 `SageMaker Training Compiler <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
147178 for Hugging Face Transformers with PyTorch:
@@ -182,29 +213,6 @@ def __init__(
182213
183214 self ._validate_args (image_uri = image_uri )
184215
185- instance_type = renamed_kwargs (
186- "train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs
187- )
188-
189- base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
190- base_framework_version = (
191- tensorflow_version if tensorflow_version is not None else pytorch_version
192- )
193-
194- if distribution is not None :
195- validate_smdistributed (
196- instance_type = instance_type ,
197- framework_name = base_framework_name ,
198- framework_version = base_framework_version ,
199- py_version = self .py_version ,
200- distribution = distribution ,
201- image_uri = image_uri ,
202- )
203-
204- warn_if_parameter_server_with_multi_gpu (
205- training_instance_type = instance_type , distribution = distribution
206- )
207-
208216 if "enable_sagemaker_metrics" not in kwargs :
209217 kwargs ["enable_sagemaker_metrics" ] = True
210218
@@ -214,6 +222,25 @@ def __init__(
214222 entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
215223 )
216224
225+ if "entry_point" not in kwargs :
226+ kwargs ["entry_point" ] = entry_point
227+
228+ self .base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
229+ self .base_framework_version = (
230+ tensorflow_version if tensorflow_version is not None else pytorch_version
231+ )
232+
233+ if distribution is not None :
234+ distribution = validate_distribution (
235+ distribution ,
236+ self .instance_groups ,
237+ self .base_framework_name ,
238+ self .base_framework_version ,
239+ py_version ,
240+ image_uri ,
241+ kwargs ,
242+ )
243+
217244 self .distribution = distribution or {}
218245
219246 if compiler_config is not None :
@@ -267,14 +294,44 @@ def _validate_args(self, image_uri):
267294 "transformers_version, tensorflow_version and pytorch_version."
268295 )
269296
297+ def _huggingface_distribution_configuration (self , distribution ):
298+ """Returns a dict of distribution config for Hugging Face training
299+
300+ Args:
301+ distribution (dict): A dictionary with information on how to run distributed training.
302+ Returns:
303+ dict containing Pytorch DDP config
304+ """
305+ distribution_config = {}
306+ pytorch_ddp_enabled = False
307+ torch_distributed_enabled = False
308+
309+ if "pytorchddp" in distribution :
310+ pytorch_ddp_enabled = distribution .get ("pytorchddp" ).get ("enabled" , False )
311+ elif "torch_distributed" in distribution :
312+ torch_distributed_enabled = distribution .get ("torch_distributed" ).get ("enabled" , False )
313+
314+ if pytorch_ddp_enabled :
315+ distribution_config [self .LAUNCH_PYTORCH_DDP_ENV_NAME ] = pytorch_ddp_enabled
316+ if self .instance_type is not None :
317+ distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
318+ elif torch_distributed_enabled :
319+ distribution_config [self .LAUNCH_TORCH_DISTRIBUTED_ENV_NAME ] = torch_distributed_enabled
320+ if self .instance_type is not None :
321+ distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
322+ else :
323+ distribution_config = self ._distribution_configuration (distribution = distribution )
324+
325+ return distribution_config
326+
270327 def hyperparameters (self ):
271328 """Return hyperparameters used by your custom PyTorch code during model training."""
272329 hyperparameters = super (HuggingFace , self ).hyperparameters ()
273- distributed_training_hyperparameters = self ._distribution_configuration (
330+ additional_hyperparameters = self ._huggingface_distribution_configuration (
274331 distribution = self .distribution
275332 )
276333 hyperparameters .update (
277- EstimatorBase ._json_encode_hyperparameters (distributed_training_hyperparameters )
334+ EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters )
278335 )
279336
280337 if self .compiler_config :
0 commit comments