1- # Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+ # Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License"). You
44# may not use this file except in compliance with the License. A copy of
@@ -303,10 +303,11 @@ def test_check_output():
303303
304304@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
305305@patch ('sagemaker.local.image._stream_output' , Mock ())
306- @patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
306+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
307+ @patch ('sagemaker.local.image._SageMakerContainer.retrieve_artifacts' )
307308@patch ('sagemaker.local.data.get_data_source_instance' )
308309@patch ('subprocess.Popen' )
309- def test_train (popen , get_data_source_instance , tmpdir , sagemaker_session ):
310+ def test_train (popen , get_data_source_instance , retrieve_artifacts , cleanup , tmpdir , sagemaker_session ):
310311 data_source = Mock ()
311312 data_source .get_root_dir .return_value = 'foo'
312313 get_data_source_instance .return_value = data_source
@@ -342,6 +343,9 @@ def test_train(popen, get_data_source_instance, tmpdir, sagemaker_session):
342343 assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output' ))
343344 assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output/data' ))
344345
346+ retrieve_artifacts .assert_called_once ()
347+ cleanup .assert_called_once ()
348+
345349
346350@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
347351@patch ('sagemaker.local.image._stream_output' , Mock ())
@@ -371,10 +375,11 @@ def test_train_with_hyperparameters_without_job_name(get_data_source_instance, t
371375
372376@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
373377@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
374- @patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
378+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
379+ @patch ('sagemaker.local.image._SageMakerContainer.retrieve_artifacts' )
375380@patch ('sagemaker.local.data.get_data_source_instance' )
376381@patch ('subprocess.Popen' , Mock ())
377- def test_train_error (get_data_source_instance , _stream_output , tmpdir , sagemaker_session ):
382+ def test_train_error (get_data_source_instance , retrieve_artifacts , cleanup , _stream_output , tmpdir , sagemaker_session ):
378383 data_source = Mock ()
379384 data_source .get_root_dir .return_value = 'foo'
380385 get_data_source_instance .return_value = data_source
@@ -391,6 +396,9 @@ def test_train_error(get_data_source_instance, _stream_output, tmpdir, sagemaker
391396
392397 assert 'this is expected' in str (e )
393398
399+ retrieve_artifacts .assert_called_once ()
400+ cleanup .assert_called_once ()
401+
394402
395403@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
396404@patch ('sagemaker.local.image._stream_output' , Mock ())
0 commit comments