2020import uuid
2121from abc import ABCMeta , abstractmethod
2222from typing import Any , Dict , Union , Optional , List
23+ from packaging .specifiers import SpecifierSet
24+ from packaging .version import Version
2325
2426from six import string_types , with_metaclass
2527from six .moves .urllib .parse import urlparse
8385)
8486from sagemaker .workflow import is_pipeline_variable
8587from sagemaker .workflow .entities import PipelineVariable
86- from sagemaker .workflow .pipeline_context import (
87- PipelineSession ,
88- runnable_by_pipeline ,
89- )
88+ from sagemaker .workflow .pipeline_context import PipelineSession , runnable_by_pipeline
9089
9190logger = logging .getLogger (__name__ )
9291
@@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
106105 LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
107106 LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
108107 LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
108+ LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"
109109 INSTANCE_TYPE = "sagemaker_instance_type"
110110 MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
111111 MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
@@ -557,9 +557,7 @@ def __init__(
557557 self .dependencies = dependencies or []
558558 self .uploaded_code = None
559559 self .tags = add_jumpstart_tags (
560- tags = tags ,
561- training_model_uri = self .model_uri ,
562- training_script_uri = self .source_dir ,
560+ tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
563561 )
564562 if self .instance_type in ("local" , "local_gpu" ):
565563 if self .instance_type == "local_gpu" and self .instance_count > 1 :
@@ -680,8 +678,7 @@ def _ensure_base_job_name(self):
680678 self .base_job_name
681679 or get_jumpstart_base_name_if_jumpstart_model (self .source_dir , self .model_uri )
682680 or base_name_from_image (
683- self .training_image_uri (),
684- default_base_name = EstimatorBase .JOB_CLASS_NAME ,
681+ self .training_image_uri (), default_base_name = EstimatorBase .JOB_CLASS_NAME
685682 )
686683 )
687684
@@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
744741 self .dependencies = updated_paths ["dependencies" ]
745742
746743 if self .source_dir or self .entry_point or self .dependencies :
747-
748744 # validate source dir will raise a ValueError if there is something wrong with
749745 # the source directory. We are intentionally not handling it because this is a
750746 # critical error.
@@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule):
10231019 parse_result = urlparse (rule .rule_parameters ["source_s3_uri" ])
10241020 if parse_result .scheme != "s3" :
10251021 desired_s3_uri = os .path .join (
1026- "s3://" ,
1027- self .sagemaker_session .default_bucket (),
1028- rule .name ,
1029- str (uuid .uuid4 ()),
1022+ "s3://" , self .sagemaker_session .default_bucket (), rule .name , str (uuid .uuid4 ())
10301023 )
10311024 s3_uri = S3Uploader .upload (
10321025 local_path = rule .rule_parameters ["source_s3_uri" ],
@@ -1439,10 +1432,7 @@ def deploy(
14391432 self ._ensure_base_job_name ()
14401433
14411434 jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model (
1442- kwargs .get ("source_dir" ),
1443- self .source_dir ,
1444- kwargs .get ("model_data" ),
1445- self .model_uri ,
1435+ kwargs .get ("source_dir" ), self .source_dir , kwargs .get ("model_data" ), self .model_uri
14461436 )
14471437 default_name = (
14481438 name_from_base (jumpstart_base_name )
@@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri):
22402230
22412231 @classmethod
22422232 def update (
2243- cls ,
2244- estimator ,
2245- profiler_rule_configs = None ,
2246- profiler_config = None ,
2247- resource_config = None ,
2233+ cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
22482234 ):
22492235 """Update a running Amazon SageMaker training job.
22502236
@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self):
31653151 )
31663152 self .debugger_hook_config = False
31673153
3154+ def _validate_mwms_config (self , distribution ):
3155+ """Validate Multi Worker Mirrored Strategy configuration."""
3156+ minimum_supported_framework_version = {"tensorflow" : {"framework_version" : "2.9" }}
3157+ if self ._framework_name in minimum_supported_framework_version :
3158+ for version_argument in minimum_supported_framework_version [self ._framework_name ]:
3159+ current = getattr (self , version_argument )
3160+ threshold = minimum_supported_framework_version [self ._framework_name ][
3161+ version_argument
3162+ ]
3163+ if Version (current ) in SpecifierSet (f"< { threshold } " ):
3164+ raise ValueError (
3165+ "Multi Worker Mirrored Strategy is only supported "
3166+ "from {} {} but received {}" .format (version_argument , threshold , current )
3167+ )
3168+ else :
3169+ raise ValueError (
3170+ "Multi Worker Mirrored Strategy is currently only supported "
3171+ "with {} frameworks but received {}" .format (
3172+ minimum_supported_framework_version .keys (), self ._framework_name
3173+ )
3174+ )
3175+ unsupported_distributions = ["smdistributed" , "parameter_server" ]
3176+ if any (i in distribution for i in unsupported_distributions ):
3177+ raise ValueError (
3178+ "Multi Worker Mirrored Strategy is currently not supported with the"
3179+ " following distribution strategies: {}" .format (unsupported_distributions )
3180+ )
3181+
31683182 def _model_source_dir (self ):
31693183 """Get the appropriate value to pass as ``source_dir`` to a model constructor.
31703184
@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution):
35283542 "dataparallel"
35293543 ].get ("custom_mpi_options" , "" )
35303544
3545+ if "multi_worker_mirrored_strategy" in distribution :
3546+ mwms_enabled = distribution .get ("multi_worker_mirrored_strategy" ).get ("enabled" , False )
3547+ if mwms_enabled :
3548+ self ._validate_mwms_config (distribution )
3549+ distribution_config [self .LAUNCH_MWMS_ENV_NAME ] = mwms_enabled
3550+
35313551 if not (mpi_enabled or smdataparallel_enabled ) and distribution_config .get (
35323552 "sagemaker_distribution_instance_groups"
35333553 ) not in [None , []]:
0 commit comments