Skip to content

Commit 524a8ce

Browse files
authored
Speed up unit test execution (#211)
Mock tar file creation on unit tests By creating a tar file out of the whole Data folder, several unit tests run really slow. Mock the tar file creation as it is not required.
1 parent d0b7384 commit 524a8ce

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.4.2dev
6+
========
7+
8+
* bug-fix: Unit Tests: Improve unit test runtime
9+
510
1.4.1
611
=====
712

@@ -18,6 +23,7 @@ CHANGELOG
1823
* feature: Analytics: Add functions for metrics in Training and Hyperparameter Tuning jobs
1924
* feature: Estimators: add support for tagging training jobs
2025

26+
2127
1.3.0
2228
=====
2329

tests/unit/test_tf_estimator.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020
from mock import patch, Mock
2121

22-
from sagemaker.fw_utils import create_image_uri
22+
from sagemaker.fw_utils import create_image_uri, UploadedCode
2323
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
2424
from sagemaker.session import s3_input
2525
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
@@ -207,19 +207,22 @@ def test_create_model(sagemaker_session, tf_version):
207207

208208
@patch('time.strftime', return_value=TIMESTAMP)
209209
@patch('time.time', return_value=TIME)
210-
def test_tf(time, strftime, sagemaker_session, tf_version):
210+
@patch('sagemaker.estimator.tar_and_upload_dir')
211+
@patch('sagemaker.model.tar_and_upload_dir')
212+
def test_tf(m_tar, e_tar, time, strftime, sagemaker_session, tf_version):
211213
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000,
212214
evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
213215
framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR)
214216

215217
inputs = 's3://mybucket/train'
216-
218+
s3_prefix = 's3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)
219+
e_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
220+
s3_prefix = 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)
221+
m_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE)
217222
tf.fit(inputs=inputs)
218223

219224
call_names = [c[0] for c in sagemaker_session.method_calls]
220225
assert call_names == ['train', 'logs_for_job']
221-
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
222-
assert boto_call_names == ['resource']
223226

224227
expected_train_args = _create_train_job(tf_version)
225228
expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs

0 commit comments

Comments
 (0)