|
13 | 13 | from __future__ import absolute_import |
14 | 14 |
|
15 | 15 | import pytest |
16 | | -from mock import Mock |
| 16 | +from mock import Mock, patch |
17 | 17 | from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator |
18 | 18 |
|
19 | 19 | MODEL_DATA = "s3://bucket/model.tar.gz" |
20 | 20 | MODEL_IMAGE = "mi" |
21 | 21 | ENTRY_POINT = "blah.py" |
22 | 22 |
|
| 23 | +TIMESTAMP = "2017-11-06-14:14:15.671" |
23 | 24 | BUCKET_NAME = "mybucket" |
24 | 25 | INSTANCE_COUNT = 1 |
25 | 26 | INSTANCE_TYPE = "ml.c5.2xlarge" |
|
31 | 32 | DEFAULT_OUTPUT_PATH = "s3://{}/".format(BUCKET_NAME) |
32 | 33 | LOCAL_DATA_PATH = "file://data" |
33 | 34 | DEFAULT_MAX_CANDIDATES = 500 |
| 35 | +DEFAULT_JOB_NAME = "automl-{}".format(TIMESTAMP) |
34 | 36 |
|
35 | 37 | JOB_NAME = "default-job-name" |
36 | 38 | JOB_NAME_2 = "banana-auto-ml-job" |
@@ -281,34 +283,38 @@ def test_auto_ml_additional_optional_params(sagemaker_session): |
281 | 283 | } |
282 | 284 |
|
283 | 285 |
|
284 | | -def test_auto_ml_default_fit(sagemaker_session): |
| 286 | +@patch("time.strftime", return_value=TIMESTAMP) |
| 287 | +def test_auto_ml_default_fit(strftime, sagemaker_session): |
285 | 288 | auto_ml = AutoML( |
286 | 289 | role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session |
287 | 290 | ) |
288 | 291 | inputs = DEFAULT_S3_INPUT_DATA |
289 | 292 | auto_ml.fit(inputs) |
290 | 293 | sagemaker_session.auto_ml.assert_called_once() |
291 | 294 | _, args = sagemaker_session.auto_ml.call_args |
292 | | - assert args["input_config"] == [ |
293 | | - { |
294 | | - "DataSource": { |
295 | | - "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
| 295 | + assert args == { |
| 296 | + "input_config": [ |
| 297 | + { |
| 298 | + "DataSource": { |
| 299 | + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
| 300 | + }, |
| 301 | + "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
| 302 | + } |
| 303 | + ], |
| 304 | + "output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH}, |
| 305 | + "auto_ml_job_config": { |
| 306 | + "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
| 307 | + "SecurityConfig": { |
| 308 | + "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
296 | 309 | }, |
297 | | - "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
298 | | - } |
299 | | - ] |
300 | | - assert args["output_config"] == {"S3OutputPath": DEFAULT_OUTPUT_PATH} |
301 | | - assert args["auto_ml_job_config"] == { |
302 | | - "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
303 | | - "SecurityConfig": { |
304 | | - "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
305 | 310 | }, |
| 311 | + "role": ROLE, |
| 312 | + "job_name": DEFAULT_JOB_NAME, |
| 313 | + "problem_type": None, |
| 314 | + "job_objective": None, |
| 315 | + "generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY, |
| 316 | + "tags": None, |
306 | 317 | } |
307 | | - assert args["role"] == ROLE |
308 | | - assert args["problem_type"] is None |
309 | | - assert args["job_objective"] is None |
310 | | - assert args["generate_candidate_definitions_only"] == GENERATE_CANDIDATE_DEFINITIONS_ONLY |
311 | | - assert args["tags"] is None |
312 | 318 |
|
313 | 319 |
|
314 | 320 | def test_auto_ml_local_input(sagemaker_session): |
|
0 commit comments