|
19 | 19 | import pytest
|
20 | 20 | from mock import patch, Mock
|
21 | 21 |
|
22 |
| -from sagemaker.fw_utils import create_image_uri |
| 22 | +from sagemaker.fw_utils import create_image_uri, UploadedCode |
23 | 23 | from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
|
24 | 24 | from sagemaker.session import s3_input
|
25 | 25 | from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
|
@@ -207,19 +207,22 @@ def test_create_model(sagemaker_session, tf_version):
|
207 | 207 |
|
208 | 208 | @patch('time.strftime', return_value=TIMESTAMP)
|
209 | 209 | @patch('time.time', return_value=TIME)
|
210 |
| -def test_tf(time, strftime, sagemaker_session, tf_version): |
| 210 | +@patch('sagemaker.estimator.tar_and_upload_dir') |
| 211 | +@patch('sagemaker.model.tar_and_upload_dir') |
| 212 | +def test_tf(m_tar, e_tar, time, strftime, sagemaker_session, tf_version): |
211 | 213 | tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000,
|
212 | 214 | evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
|
213 | 215 | framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR)
|
214 | 216 |
|
215 | 217 | inputs = 's3://mybucket/train'
|
216 |
| - |
| 218 | + s3_prefix = 's3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME) |
| 219 | + e_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE) |
| 220 | + s3_prefix = 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME) |
| 221 | + m_tar.return_value = UploadedCode(s3_prefix=s3_prefix, script_name=SCRIPT_FILE) |
217 | 222 | tf.fit(inputs=inputs)
|
218 | 223 |
|
219 | 224 | call_names = [c[0] for c in sagemaker_session.method_calls]
|
220 | 225 | assert call_names == ['train', 'logs_for_job']
|
221 |
| - boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] |
222 |
| - assert boto_call_names == ['resource'] |
223 | 226 |
|
224 | 227 | expected_train_args = _create_train_job(tf_version)
|
225 | 228 | expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs
|
|
0 commit comments