Skip to content

Commit 6248b7a

Browse files
authored
change: make Local Mode export artifacts even after failure (#746)
1 parent: 5ffdcc0 · commit: 6248b7a

File tree

3 files changed: 42 additions and 12 deletions.

src/sagemaker/local/image.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -137,14 +137,15 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
137137
# which contains the exit code and append the command line to it.
138138
msg = "Failed to run: %s, %s" % (compose_command, str(e))
139139
raise RuntimeError(msg)
140+
finally:
141+
artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)
140142

141-
artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)
143+
# free up the training data directory as it may contain
144+
# lots of data downloaded from S3. This doesn't delete any local
145+
# data that was just mounted to the container.
146+
dirs_to_delete = [data_dir, shared_dir]
147+
self._cleanup(dirs_to_delete)
142148

143-
# free up the training data directory as it may contain
144-
# lots of data downloaded from S3. This doesn't delete any local
145-
# data that was just mounted to the container.
146-
dirs_to_delete = [data_dir, shared_dir]
147-
self._cleanup(dirs_to_delete)
148149
# Print our Job Complete line to have a similar experience to training on SageMaker where you
149150
# see this line at the end.
150151
print('===== Job Complete =====')

tests/integ/test_local_mode.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import tarfile
1617

1718
import boto3
1819
import numpy
@@ -318,6 +319,26 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
318319
mx.delete_endpoint()
319320

320321

322+
@pytest.mark.local_mode
323+
def test_mxnet_training_failure(sagemaker_local_session, mxnet_full_version, tmpdir):
324+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
325+
326+
mx = MXNet(entry_point=script_path,
327+
role='SageMakerRole',
328+
framework_version=mxnet_full_version,
329+
py_version=PYTHON_VERSION,
330+
train_instance_count=1,
331+
train_instance_type='local',
332+
sagemaker_session=sagemaker_local_session,
333+
output_path='file://{}'.format(tmpdir))
334+
335+
with pytest.raises(RuntimeError):
336+
mx.fit()
337+
338+
with tarfile.open(os.path.join(str(tmpdir), 'output.tar.gz')) as tar:
339+
tar.getmember('failure')
340+
341+
321342
@pytest.mark.local_mode
322343
def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_version):
323344
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')

tests/unit/test_image.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -303,10 +303,11 @@ def test_check_output():
303303

304304
@patch('sagemaker.local.local_session.LocalSession', Mock())
305305
@patch('sagemaker.local.image._stream_output', Mock())
306-
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
306+
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
307+
@patch('sagemaker.local.image._SageMakerContainer.retrieve_artifacts')
307308
@patch('sagemaker.local.data.get_data_source_instance')
308309
@patch('subprocess.Popen')
309-
def test_train(popen, get_data_source_instance, tmpdir, sagemaker_session):
310+
def test_train(popen, get_data_source_instance, retrieve_artifacts, cleanup, tmpdir, sagemaker_session):
310311
data_source = Mock()
311312
data_source.get_root_dir.return_value = 'foo'
312313
get_data_source_instance.return_value = data_source
@@ -342,6 +343,9 @@ def test_train(popen, get_data_source_instance, tmpdir, sagemaker_session):
342343
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output'))
343344
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output/data'))
344345

346+
retrieve_artifacts.assert_called_once()
347+
cleanup.assert_called_once()
348+
345349

346350
@patch('sagemaker.local.local_session.LocalSession', Mock())
347351
@patch('sagemaker.local.image._stream_output', Mock())
@@ -371,10 +375,11 @@ def test_train_with_hyperparameters_without_job_name(get_data_source_instance, t
371375

372376
@patch('sagemaker.local.local_session.LocalSession', Mock())
373377
@patch('sagemaker.local.image._stream_output', side_effect=RuntimeError('this is expected'))
374-
@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
378+
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
379+
@patch('sagemaker.local.image._SageMakerContainer.retrieve_artifacts')
375380
@patch('sagemaker.local.data.get_data_source_instance')
376381
@patch('subprocess.Popen', Mock())
377-
def test_train_error(get_data_source_instance, _stream_output, tmpdir, sagemaker_session):
382+
def test_train_error(get_data_source_instance, retrieve_artifacts, cleanup, _stream_output, tmpdir, sagemaker_session):
378383
data_source = Mock()
379384
data_source.get_root_dir.return_value = 'foo'
380385
get_data_source_instance.return_value = data_source
@@ -391,6 +396,9 @@ def test_train_error(get_data_source_instance, _stream_output, tmpdir, sagemaker
391396

392397
assert 'this is expected' in str(e)
393398

399+
retrieve_artifacts.assert_called_once()
400+
cleanup.assert_called_once()
401+
394402

395403
@patch('sagemaker.local.local_session.LocalSession', Mock())
396404
@patch('sagemaker.local.image._stream_output', Mock())

0 commit comments

Comments
 (0)