@@ -2125,7 +2125,6 @@ def test_generic_deploy_accelerator_type(sagemaker_session):
21252125 e .deploy (INSTANCE_COUNT , INSTANCE_TYPE , accelerator_type = ACCELERATOR_TYPE )
21262126
21272127 args = e .sagemaker_session .endpoint_from_production_variants .call_args [1 ]
2128- print (args )
21292128 assert args ["name" ].startswith (IMAGE_URI )
21302129 assert args ["production_variants" ][0 ]["AcceleratorType" ] == ACCELERATOR_TYPE
21312130 assert args ["production_variants" ][0 ]["InitialInstanceCount" ] == INSTANCE_COUNT
@@ -2182,7 +2181,6 @@ def test_local_mode(session_class, local_session_class):
21822181 session_class .return_value = session
21832182
21842183 e = Estimator (IMAGE_URI , ROLE , INSTANCE_COUNT , "local" )
2185- print (e .sagemaker_session .local_mode )
21862184 assert e .sagemaker_session .local_mode is True
21872185
21882186 e2 = Estimator (IMAGE_URI , ROLE , INSTANCE_COUNT , "local_gpu" )
@@ -2248,6 +2246,25 @@ def test_prepare_init_params_from_job_description_with_algorithm_training_job():
22482246 )
22492247
22502248
2249+ def test_prepare_init_params_from_job_description_with_spot_training ():
2250+ job_description = RETURNED_JOB_DESCRIPTION .copy ()
2251+ job_description ["EnableManagedSpotTraining" ] = True
2252+ job_description ["StoppingCondition" ] = {
2253+ "MaxRuntimeInSeconds" : 86400 ,
2254+ "MaxWaitTimeInSeconds" : 87000 ,
2255+ }
2256+
2257+ init_params = EstimatorBase ._prepare_init_params_from_job_description (
2258+ job_details = job_description
2259+ )
2260+
2261+ assert init_params ["role" ] == "arn:aws:iam::366:role/SageMakerRole"
2262+ assert init_params ["instance_count" ] == 1
2263+ assert init_params ["use_spot_instances" ]
2264+ assert init_params ["max_run" ] == 86400
2265+ assert init_params ["max_wait" ] == 87000
2266+
2267+
22512268def test_prepare_init_params_from_job_description_with_invalid_training_job ():
22522269
22532270 invalid_job_description = RETURNED_JOB_DESCRIPTION .copy ()
0 commit comments