1
- # Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1
+ # Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
4
# may not use this file except in compliance with the License. A copy of
@@ -303,10 +303,11 @@ def test_check_output():
303
303
304
304
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
305
305
@patch ('sagemaker.local.image._stream_output' , Mock ())
306
- @patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
306
+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
307
+ @patch ('sagemaker.local.image._SageMakerContainer.retrieve_artifacts' )
307
308
@patch ('sagemaker.local.data.get_data_source_instance' )
308
309
@patch ('subprocess.Popen' )
309
- def test_train (popen , get_data_source_instance , tmpdir , sagemaker_session ):
310
+ def test_train (popen , get_data_source_instance , retrieve_artifacts , cleanup , tmpdir , sagemaker_session ):
310
311
data_source = Mock ()
311
312
data_source .get_root_dir .return_value = 'foo'
312
313
get_data_source_instance .return_value = data_source
@@ -342,6 +343,9 @@ def test_train(popen, get_data_source_instance, tmpdir, sagemaker_session):
342
343
assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output' ))
343
344
assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output/data' ))
344
345
346
+ retrieve_artifacts .assert_called_once ()
347
+ cleanup .assert_called_once ()
348
+
345
349
346
350
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
347
351
@patch ('sagemaker.local.image._stream_output' , Mock ())
@@ -371,10 +375,11 @@ def test_train_with_hyperparameters_without_job_name(get_data_source_instance, t
371
375
372
376
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
373
377
@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
374
- @patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
378
+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
379
+ @patch ('sagemaker.local.image._SageMakerContainer.retrieve_artifacts' )
375
380
@patch ('sagemaker.local.data.get_data_source_instance' )
376
381
@patch ('subprocess.Popen' , Mock ())
377
- def test_train_error (get_data_source_instance , _stream_output , tmpdir , sagemaker_session ):
382
+ def test_train_error (get_data_source_instance , retrieve_artifacts , cleanup , _stream_output , tmpdir , sagemaker_session ):
378
383
data_source = Mock ()
379
384
data_source .get_root_dir .return_value = 'foo'
380
385
get_data_source_instance .return_value = data_source
@@ -391,6 +396,9 @@ def test_train_error(get_data_source_instance, _stream_output, tmpdir, sagemaker
391
396
392
397
assert 'this is expected' in str (e )
393
398
399
+ retrieve_artifacts .assert_called_once ()
400
+ cleanup .assert_called_once ()
401
+
394
402
395
403
@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
396
404
@patch ('sagemaker.local.image._stream_output' , Mock ())
0 commit comments