1616
1717import pytest
1818import yaml
19- from mock import patch , Mock
19+ from mock import call , patch , Mock
2020
2121import sagemaker
2222from sagemaker .local .image import _SageMakerContainer
4040 'S3DataSource' : {
4141 'S3DataDistributionType' : 'FullyReplicated' ,
4242 'S3DataType' : 'S3Prefix' ,
43- 'S3Uri' : 's3://foo/bar '
43+ 'S3Uri' : 's3://my-own-bucket/prefix '
4444 }
4545 }
4646 }
@@ -54,12 +54,12 @@ def sagemaker_session():
5454 boto_mock .client ('sts' ).get_caller_identity .return_value = {'Account' : '123' }
5555 boto_mock .resource ('s3' ).Bucket (BUCKET_NAME ).objects .filter .return_value = []
5656
57- ims = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
57+ sms = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
5858
59- ims .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
60- ims .expand_role = Mock (return_value = EXPANDED_ROLE )
59+ sms .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
60+ sms .expand_role = Mock (return_value = EXPANDED_ROLE )
6161
62- return ims
62+ return sms
6363
6464
6565@patch ('sagemaker.local.local_session.LocalSession' )
@@ -181,16 +181,22 @@ def test_check_output():
181181@patch ('sagemaker.local.local_session.LocalSession' )
182182@patch ('sagemaker.local.image._execute_and_stream_output' )
183183@patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
184- def test_train (LocalSession , _execute_and_stream_output , _cleanup , tmpdir , sagemaker_session ):
184+ @patch ('sagemaker.local.image._SageMakerContainer._download_folder' )
185+ def test_train (_download_folder , _cleanup , _execute_and_stream_output , LocalSession , tmpdir , sagemaker_session ):
185186
187+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
186188 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
187- side_effect = [ str ( tmpdir . mkdir ( 'container-root' )), str ( tmpdir . mkdir ( 'data' ))] ):
189+ side_effect = directories ):
188190
189191 instance_count = 2
190192 image = 'my-image'
191193 sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
192194 sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
193195
196+ channel_dir = os .path .join (directories [1 ], 'b' )
197+ download_folder_calls = [call ('my-own-bucket' , 'prefix' , channel_dir )]
198+ _download_folder .assert_has_calls (download_folder_calls )
199+
194200 docker_compose_file = os .path .join (sagemaker_container .container_root , 'docker-compose.yaml' )
195201
196202 call_args = _execute_and_stream_output .call_args [0 ][0 ]
@@ -231,6 +237,36 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
231237 assert config ['services' ][h ]['command' ] == 'serve'
232238
233239
@patch('os.makedirs')
def test_download_folder(makedirs):
    """Check that ``_download_folder`` downloads each object listed under a prefix.

    ``os.makedirs`` is replaced by a mock (``makedirs``) so no directories are
    created by the test itself. All S3 access goes through ``boto_mock``.
    """
    boto_mock = Mock(name='boto_session')
    boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}

    session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())

    # Fake S3 object summaries. boto3 exposes ``bucket_name`` and ``key`` as
    # plain attributes, so set them as values rather than as ``return_value``
    # of a callable -- otherwise the code under test sees a Mock, not a string.
    train_data = Mock()
    train_data.bucket_name = BUCKET_NAME
    train_data.key = '/prefix/train/train_data.csv'

    validation_data = Mock()
    validation_data.bucket_name = BUCKET_NAME
    validation_data.key = '/prefix/train/validation_data.csv'

    s3_files = [train_data, validation_data]
    boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files

    # Every ``s3.Object(...)`` lookup yields the same mock, so all downloads
    # are recorded on a single ``download_file`` attribute.
    obj_mock = Mock()
    boto_mock.resource('s3').Object.return_value = obj_mock

    sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=session)
    sagemaker_container._download_folder(BUCKET_NAME, '/prefix', '/tmp')

    # Expected targets: the key with the '/prefix' part removed, re-rooted at '/tmp'.
    # ``assert_has_calls`` with a non-empty list already fails when nothing was
    # called, so a separate ``assert_called()`` is unnecessary.
    expected_downloads = [
        call(os.path.join('/tmp', 'train/train_data.csv')),
        call(os.path.join('/tmp', 'train/validation_data.csv')),
    ]
    obj_mock.download_file.assert_has_calls(expected_downloads)
269+
234270def test_ecr_login_non_ecr ():
235271 session_mock = Mock ()
236272 sagemaker .local .image ._ecr_login_if_needed (session_mock , 'ubuntu' )
0 commit comments