@@ -186,6 +186,7 @@ def __init__(
186186        enable_remote_debug : Optional [Union [bool , PipelineVariable ]] =  None ,
187187        enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] =  None ,
188188        training_plan : Optional [Union [str , PipelineVariable ]] =  None ,
189+         instance_placement_config : Optional [Dict ] =  None ,
189190        ** kwargs ,
190191    ):
191192        """Initialize an ``EstimatorBase`` instance. 
@@ -560,6 +561,21 @@ def __init__(
560561                Specifies whether SessionTagChaining is enabled for the training job. 
561562            training_plan (str or PipelineVariable): Optional. 
562563                Specifies which training plan arn to use for the training job 
564+             instance_placement_config (dict): Optional. 
565+                 Specifies UltraServer placement configuration for the training job 
566+ 
567+                 .. code:: python 
568+ 
569+                     instance_placement_config={ 
570+                         "EnableMultipleJobs": True, 
571+                         "PlacementSpecifications":[ 
572+                             { 
573+                                 "UltraServerId": "ultraserver-1", 
574+                                 "InstanceCount": "2" 
575+                             } 
576+                         ] 
577+                     } 
578+ 
563579        """ 
564580        instance_count  =  renamed_kwargs (
565581            "train_instance_count" , "instance_count" , instance_count , kwargs 
@@ -813,6 +829,8 @@ def __init__(
813829
814830        self .training_plan  =  training_plan 
815831
832+         self .instance_placement_config  =  instance_placement_config 
833+ 
816834        # Internal flag 
817835        self ._is_output_path_set_from_default_bucket_and_prefix  =  False 
818836
@@ -1997,6 +2015,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19972015        if  "TrainingPlanArn"  in  job_details ["ResourceConfig" ]:
19982016            init_params ["training_plan" ] =  job_details ["ResourceConfig" ]["TrainingPlanArn" ]
19992017
2018+         if  "InstancePlacementConfig"  in  job_details ["ResourceConfig" ]:
2019+             init_params ["instance_placement_config" ] =  job_details ["ResourceConfig" ][
2020+                 "InstancePlacementConfig" 
2021+             ]
2022+ 
20002023        has_hps  =  "HyperParameters"  in  job_details 
20012024        init_params ["hyperparameters" ] =  job_details ["HyperParameters" ] if  has_hps  else  {}
20022025
@@ -2882,6 +2905,7 @@ def __init__(
28822905        enable_remote_debug : Optional [Union [bool , PipelineVariable ]] =  None ,
28832906        enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] =  None ,
28842907        training_plan : Optional [Union [str , PipelineVariable ]] =  None ,
2908+         instance_placement_config : Optional [Dict ] =  None ,
28852909        ** kwargs ,
28862910    ):
28872911        """Initialize an ``Estimator`` instance. 
@@ -3249,6 +3273,20 @@ def __init__(
32493273                 Specifies whether SessionTagChaining is enabled for the training job 
32503274            training_plan (str or PipelineVariable): Optional. 
32513275                Specifies which training plan arn to use for the training job 
3276+             instance_placement_config (dict): Optional. 
3277+                 Specifies UltraServer placement configuration for the training job 
3278+ 
3279+                 .. code:: python 
3280+ 
3281+                     instance_placement_config={ 
3282+                         "EnableMultipleJobs": True, 
3283+                         "PlacementSpecifications":[ 
3284+                             { 
3285+                                 "UltraServerId": "ultraserver-1", 
3286+                                 "InstanceCount": "2" 
3287+                             } 
3288+                         ] 
3289+                     } 
32523290        """ 
32533291        self .image_uri  =  image_uri 
32543292        self ._hyperparameters  =  hyperparameters .copy () if  hyperparameters  else  {}
@@ -3303,6 +3341,7 @@ def __init__(
33033341            enable_remote_debug = enable_remote_debug ,
33043342            enable_session_tag_chaining = enable_session_tag_chaining ,
33053343            training_plan = training_plan ,
3344+             instance_placement_config = instance_placement_config ,
33063345            ** kwargs ,
33073346        )
33083347
0 commit comments