@@ -761,6 +761,64 @@ def test_container_log_level(sagemaker_session):
761761 assert train_kwargs ['hyperparameters' ]['sagemaker_container_log_level' ] == '10'
762762
763763
764+ @patch ('sagemaker.utils' )
765+ def test_same_code_location_keeps_kms_key (utils , sagemaker_session ):
766+ fw = DummyFramework (entry_point = SCRIPT_PATH ,
767+ role = 'DummyRole' ,
768+ sagemaker_session = sagemaker_session ,
769+ train_instance_count = INSTANCE_COUNT ,
770+ train_instance_type = INSTANCE_TYPE ,
771+ output_kms_key = 'kms-key' )
772+
773+ fw .fit (wait = False )
774+
775+ extra_args = {'ServerSideEncryption' : 'aws:kms' , 'SSEKMSKeyId' : 'kms-key' }
776+ obj = sagemaker_session .boto_session .resource ('s3' ).Object
777+
778+ obj .assert_called_with ('mybucket' , '%s/source/sourcedir.tar.gz' % fw ._current_job_name )
779+
780+ obj ().upload_file .assert_called_with (utils .create_tar_file (), ExtraArgs = extra_args )
781+
782+
783+ @patch ('sagemaker.utils' )
784+ def test_different_code_location_kms_key (utils , sagemaker_session ):
785+ fw = DummyFramework (entry_point = SCRIPT_PATH ,
786+ role = 'DummyRole' ,
787+ sagemaker_session = sagemaker_session ,
788+ code_location = 's3://another-location' ,
789+ train_instance_count = INSTANCE_COUNT ,
790+ train_instance_type = INSTANCE_TYPE ,
791+ output_kms_key = 'kms-key' )
792+
793+ fw .fit (wait = False )
794+
795+ obj = sagemaker_session .boto_session .resource ('s3' ).Object
796+
797+ obj .assert_called_with ('another-location' , '%s/source/sourcedir.tar.gz' % fw ._current_job_name )
798+
799+ obj ().upload_file .assert_called_with (utils .create_tar_file (), ExtraArgs = None )
800+
801+
802+ @patch ('sagemaker.utils' )
803+ def test_default_code_location_uses_output_path (utils , sagemaker_session ):
804+ fw = DummyFramework (entry_point = SCRIPT_PATH ,
805+ role = 'DummyRole' ,
806+ sagemaker_session = sagemaker_session ,
807+ output_path = 's3://output_path' ,
808+ train_instance_count = INSTANCE_COUNT ,
809+ train_instance_type = INSTANCE_TYPE ,
810+ output_kms_key = 'kms-key' )
811+
812+ fw .fit (wait = False )
813+
814+ obj = sagemaker_session .boto_session .resource ('s3' ).Object
815+
816+ obj .assert_called_with ('output_path' , '%s/source/sourcedir.tar.gz' % fw ._current_job_name )
817+
818+ extra_args = {'ServerSideEncryption' : 'aws:kms' , 'SSEKMSKeyId' : 'kms-key' }
819+ obj ().upload_file .assert_called_with (utils .create_tar_file (), ExtraArgs = extra_args )
820+
821+
764822def test_wait_without_logs (sagemaker_session ):
765823 training_job = _TrainingJob (sagemaker_session , JOB_NAME )
766824
0 commit comments