@@ -310,9 +310,24 @@ def test_generate_tensorboard_url_domain_non_string():
310310@patch ("os.makedirs" )
311311def test_download_folder (makedirs ):
312312 boto_mock = Mock (name = "boto_session" )
313- boto_mock .client ("sts" ).get_caller_identity .return_value = {"Account" : "123" }
314-
315313 session = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
314+ s3_mock = boto_mock .resource ("s3" )
315+
316+ obj_mock = Mock ()
317+ s3_mock .Object .return_value = obj_mock
318+
319+ def obj_mock_download (path ):
320+ # Mock the S3 object to raise an error when the input to download_file
321+ # is a "folder"
322+ if path in ("/tmp/" , os .path .join ("/tmp" , "prefix" )):
323+ raise botocore .exceptions .ClientError (
324+ error_response = {"Error" : {"Code" : "404" , "Message" : "Not Found" }},
325+ operation_name = "HeadObject" ,
326+ )
327+ else :
328+ return Mock ()
329+
330+ obj_mock .download_file .side_effect = obj_mock_download
316331
317332 train_data = Mock ()
318333 validation_data = Mock ()
@@ -323,23 +338,20 @@ def test_download_folder(makedirs):
323338 validation_data .key = "prefix/train/validation_data.csv"
324339
325340 s3_files = [train_data , validation_data ]
326- boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .return_value = s3_files
327-
328- obj_mock = Mock ()
329- boto_mock .resource ("s3" ).Object .return_value = obj_mock
341+ s3_mock .Bucket (BUCKET_NAME ).objects .filter .return_value = s3_files
330342
331343 # all the S3 mocks are set, the test itself begins now.
332344 sagemaker .utils .download_folder (BUCKET_NAME , "/prefix" , "/tmp" , session )
333345
334346 obj_mock .download_file .assert_called ()
335347 calls = [
336- call (os .path .join ("/tmp" , "train/ train_data.csv" )),
337- call (os .path .join ("/tmp" , "train/ validation_data.csv" )),
348+ call (os .path .join ("/tmp" , "train" , " train_data.csv" )),
349+ call (os .path .join ("/tmp" , "train" , " validation_data.csv" )),
338350 ]
339351 obj_mock .download_file .assert_has_calls (calls )
340352 obj_mock .reset_mock ()
341353
342- # Testing with a trailing slash for the prefix.
354+ # Test with a trailing slash for the prefix.
343355 sagemaker .utils .download_folder (BUCKET_NAME , "/prefix/" , "/tmp" , session )
344356 obj_mock .download_file .assert_called ()
345357 obj_mock .download_file .assert_has_calls (calls )
@@ -369,7 +381,7 @@ def test_download_folder_points_to_single_file(makedirs):
369381 obj_mock .download_file .assert_called ()
370382 calls = [call (os .path .join ("/tmp" , "train_data.csv" ))]
371383 obj_mock .download_file .assert_has_calls (calls )
372- assert boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .call_count == 1
384+ boto_mock .resource ("s3" ).Bucket (BUCKET_NAME ).objects .filter .assert_not_called ()
373385 obj_mock .reset_mock ()
374386
375387
0 commit comments