diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 0055416327..8cd6410ea0 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -186,6 +186,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -560,6 +561,21 @@ def __init__( Specifies whether SessionTagChaining is enabled for the training job. training_plan (str or PipelineVariable): Optional. Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional. + Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } + """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -813,6 +829,8 @@ def __init__( self.training_plan = training_plan + self.instance_placement_config = instance_placement_config + # Internal flag self._is_output_path_set_from_default_bucket_and_prefix = False @@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na if "TrainingPlanArn" in job_details["ResourceConfig"]: init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"] + if "InstancePlacementConfig" in job_details["ResourceConfig"]: + init_params["instance_placement_config"] = job_details["ResourceConfig"][ + "InstancePlacementConfig" + ] + has_hps = "HyperParameters" in job_details init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {} @@ -2882,6 +2905,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3249,6 +3273,20 @@ def __init__( Specifies whether SessionTagChaining is enabled for the training job training_plan (str or PipelineVariable): Optional. Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional. + Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3303,6 +3341,7 @@ def __init__( enable_remote_debug=enable_remote_debug, enable_session_tag_chaining=enable_session_tag_chaining, training_plan=training_plan, + instance_placement_config=instance_placement_config, **kwargs, ) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 1ad7e3b981..6917421c04 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -85,6 +85,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): estimator.volume_kms_key, estimator.keep_alive_period_in_seconds, estimator.training_plan, + estimator.instance_placement_config, ) stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait) vpc_config = estimator.get_vpc_config() @@ -333,6 +334,7 @@ def _prepare_resource_config( volume_kms_key, keep_alive_period_in_seconds, training_plan, + instance_placement_config=None, ): """Placeholder docstring""" resource_config = { @@ -360,6 +362,8 @@ def _prepare_resource_config( resource_config["InstanceType"] = instance_type if training_plan is not None: resource_config["TrainingPlanArn"] = training_plan + if instance_placement_config is not None: + resource_config["InstancePlacementConfig"] = instance_placement_config return resource_config diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4daf9b1810..e61e1c49a5 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -119,6 +119,7 @@ def __init__( config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, ): """Initializes a ``JumpStartEstimator``. @@ -517,6 +518,20 @@ def __init__( Specifies whether SessionTagChaining is enabled for the training job training_plan (str or PipelineVariable): Optional. Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional. + Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -606,6 +621,7 @@ def _validate_model_id_and_get_type_hook(): config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, training_plan=training_plan, + instance_placement_config=instance_placement_config, ) self.hub_arn = estimator_init_kwargs.hub_arn diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 051cda0f4a..81e1356050 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -145,6 +145,7 @@ def get_init_kwargs( config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -207,6 +208,7 @@ def get_init_kwargs( config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, training_plan=training_plan, + instance_placement_config=instance_placement_config, ) estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5b45b21bd8..f545425a51 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2445,6 +2445,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_reference_arn", "specs", "training_plan", + "instance_placement_config", ] SERIALIZATION_EXCLUSION_SET = { @@ -2519,6 +2520,7 @@ def __init__( config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2582,6 +2584,7 @@ def __init__( self.config_name = config_name self.enable_session_tag_chaining = enable_session_tag_chaining self.training_plan = training_plan + self.instance_placement_config = instance_placement_config class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index cfb243b563..1698da3e90 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -76,6 +76,8 @@ ) from sagemaker.model_life_cycle import ModelLifeCycle +from tests.unit.test_job import INSTANCE_PLACEMENT_CONFIG + MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" ENTRY_POINT = "blah.py" @@ -879,6 +881,22 @@ def test_framework_with_training_plan(sagemaker_session): assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN +def test_framework_with_instance_placement(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_type="ml.c4.xlarge", + instance_count=2, + training_plan=TRAINING_PLAN, + instance_placement_config=INSTANCE_PLACEMENT_CONFIG, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["resource_config"]["InstancePlacementConfig"] == INSTANCE_PLACEMENT_CONFIG + + def test_framework_with_both_training_repository_config(sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index dc21f50b68..cdd4a2630e 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -32,6 +32,10 @@ INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD = 1800 TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" +INSTANCE_PLACEMENT_CONFIG = { + "EnableMultipleJobs": True, + "PlacementSpecifications": [{"UltraServerId": "us-1", "InstanceCount": "2"}], +} INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1) VOLUME_SIZE = 1 MAX_RUNTIME = 1 @@ -756,6 +760,28 @@ def test_prepare_resource_config_with_training_plan(): } +def test_prepare_resource_config_with_placement_config(): + resource_config = _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + TRAINING_PLAN, + INSTANCE_PLACEMENT_CONFIG, + ) + + assert resource_config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": VOLUME_SIZE, + "VolumeKmsKeyId": VOLUME_KMS_KEY, + "TrainingPlanArn": TRAINING_PLAN, + "InstancePlacementConfig": INSTANCE_PLACEMENT_CONFIG, + } + + def test_prepare_resource_config_with_keep_alive_period(): resource_config = _Job._prepare_resource_config( INSTANCE_COUNT,