@@ -346,32 +346,66 @@ def test_deploy_creates_correct_session(local_session, session, tmpdir):
346346@patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
347347def test_deploy_update_endpoint (sagemaker_session , tmpdir ):
348348 model = DummyFrameworkModel (sagemaker_session , source_dir = tmpdir )
349+ model .deploy (instance_type = INSTANCE_TYPE , initial_instance_count = 1 , update_endpoint = True )
350+ sagemaker_session .create_endpoint_config .assert_called_with (
351+ name = model .name ,
352+ model_name = model .name ,
353+ initial_instance_count = INSTANCE_COUNT ,
354+ instance_type = INSTANCE_TYPE ,
355+ accelerator_type = None ,
356+ tags = None ,
357+ kms_key = None ,
358+ data_capture_config_dict = None ,
359+ )
360+ config_name = sagemaker_session .create_endpoint_config (
361+ name = model .name ,
362+ model_name = model .name ,
363+ initial_instance_count = INSTANCE_COUNT ,
364+ instance_type = INSTANCE_TYPE ,
365+ accelerator_type = ACCELERATOR_TYPE ,
366+ )
367+ sagemaker_session .update_endpoint .assert_called_with (model .name , config_name , wait = True )
368+ sagemaker_session .create_endpoint .assert_not_called ()
369+
370+
371+ @patch ("sagemaker.fw_utils.tar_and_upload_dir" , MagicMock ())
372+ def test_deploy_update_endpoint_optional_args (sagemaker_session , tmpdir ):
349373 endpoint_name = "endpoint-name"
374+ tags = [{"Key" : "Value" }]
375+ kms_key = "foo"
376+ data_capture_config = MagicMock ()
377+
378+ model = DummyFrameworkModel (sagemaker_session , source_dir = tmpdir )
350379 model .deploy (
351380 instance_type = INSTANCE_TYPE ,
352381 initial_instance_count = 1 ,
353- endpoint_name = endpoint_name ,
354382 update_endpoint = True ,
383+ endpoint_name = endpoint_name ,
355384 accelerator_type = ACCELERATOR_TYPE ,
385+ tags = tags ,
386+ kms_key = kms_key ,
387+ wait = False ,
388+ data_capture_config = data_capture_config ,
356389 )
357390 sagemaker_session .create_endpoint_config .assert_called_with (
358391 name = model .name ,
359392 model_name = model .name ,
360393 initial_instance_count = INSTANCE_COUNT ,
361394 instance_type = INSTANCE_TYPE ,
362395 accelerator_type = ACCELERATOR_TYPE ,
363- tags = None ,
364- kms_key = None ,
365- data_capture_config_dict = None ,
396+ tags = tags ,
397+ kms_key = kms_key ,
398+ data_capture_config_dict = data_capture_config . _to_request_dict () ,
366399 )
367400 config_name = sagemaker_session .create_endpoint_config (
368401 name = model .name ,
369402 model_name = model .name ,
370403 initial_instance_count = INSTANCE_COUNT ,
371404 instance_type = INSTANCE_TYPE ,
372405 accelerator_type = ACCELERATOR_TYPE ,
406+ wait = False ,
373407 )
374- sagemaker_session .update_endpoint .assert_called_with (endpoint_name , config_name )
408+ sagemaker_session .update_endpoint .assert_called_with (endpoint_name , config_name , wait = False )
375409 sagemaker_session .create_endpoint .assert_not_called ()
376410
377411
0 commit comments