|
13 | 13 | from __future__ import absolute_import |
14 | 14 |
|
15 | 15 | import pytest |
16 | | -from mock import Mock, patch |
| 16 | +from mock import Mock |
17 | 17 | from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator |
18 | 18 |
|
19 | 19 | MODEL_DATA = "s3://bucket/model.tar.gz" |
20 | 20 | MODEL_IMAGE = "mi" |
21 | 21 | ENTRY_POINT = "blah.py" |
22 | 22 |
|
23 | | -TIMESTAMP = "2017-11-06-14:14:15.671" |
24 | 23 | BUCKET_NAME = "mybucket" |
25 | 24 | INSTANCE_COUNT = 1 |
26 | 25 | INSTANCE_TYPE = "ml.c5.2xlarge" |
|
32 | 31 | DEFAULT_OUTPUT_PATH = "s3://{}/".format(BUCKET_NAME) |
33 | 32 | LOCAL_DATA_PATH = "file://data" |
34 | 33 | DEFAULT_MAX_CANDIDATES = 500 |
35 | | -DEFAULT_JOB_NAME = "sagemake-{}".format(TIMESTAMP) |
36 | 34 |
|
37 | 35 | JOB_NAME = "default-job-name" |
38 | 36 | JOB_NAME_2 = "banana-auto-ml-job" |
@@ -283,38 +281,34 @@ def test_auto_ml_additional_optional_params(sagemaker_session): |
283 | 281 | } |
284 | 282 |
|
285 | 283 |
|
286 | | -@patch("time.strftime", return_value=TIMESTAMP) |
287 | | -def test_auto_ml_default_fit(strftime, sagemaker_session): |
| 284 | +def test_auto_ml_default_fit(sagemaker_session): |
288 | 285 | auto_ml = AutoML( |
289 | 286 | role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session |
290 | 287 | ) |
291 | 288 | inputs = DEFAULT_S3_INPUT_DATA |
292 | 289 | auto_ml.fit(inputs) |
293 | 290 | sagemaker_session.auto_ml.assert_called_once() |
294 | 291 | _, args = sagemaker_session.auto_ml.call_args |
295 | | - assert args == { |
296 | | - "input_config": [ |
297 | | - { |
298 | | - "DataSource": { |
299 | | - "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
300 | | - }, |
301 | | - "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
302 | | - } |
303 | | - ], |
304 | | - "output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH}, |
305 | | - "auto_ml_job_config": { |
306 | | - "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
307 | | - "SecurityConfig": { |
308 | | - "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
| 292 | + assert args["input_config"] == [ |
| 293 | + { |
| 294 | + "DataSource": { |
| 295 | + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
309 | 296 | }, |
| 297 | + "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
| 298 | + } |
| 299 | + ] |
| 300 | + assert args["output_config"] == {"S3OutputPath": DEFAULT_OUTPUT_PATH} |
| 301 | + assert args["auto_ml_job_config"] == { |
| 302 | + "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
| 303 | + "SecurityConfig": { |
| 304 | + "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
310 | 305 | }, |
311 | | - "role": ROLE, |
312 | | - "job_name": DEFAULT_JOB_NAME, |
313 | | - "problem_type": None, |
314 | | - "job_objective": None, |
315 | | - "generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY, |
316 | | - "tags": None, |
317 | 306 | } |
| 307 | + assert args["role"] == ROLE |
| 308 | + assert args["problem_type"] is None |
| 309 | + assert args["job_objective"] is None |
| 310 | + assert args["generate_candidate_definitions_only"] == GENERATE_CANDIDATE_DEFINITIONS_ONLY |
| 311 | + assert args["tags"] is None |
318 | 312 |
|
319 | 313 |
|
320 | 314 | def test_auto_ml_local_input(sagemaker_session): |
|
0 commit comments