Skip to content

Commit 907f923

Browse files
beniericgkatkovhaoxinwaJennaZhaojessicazhu3
authored
feature: AWS Batch for SageMaker Training jobs (#5249)
--------- Co-authored-by: Greg Katkov <[email protected]> Co-authored-by: haoxinwa <[email protected]> Co-authored-by: JennaZhao <[email protected]> Co-authored-by: Jessica Zhu <[email protected]> Co-authored-by: David Lindskog <[email protected]>
1 parent ed4fbe8 commit 907f923

26 files changed

+2354
-78
lines changed

src/sagemaker/aws_batch/__init__.py

Whitespace-only changes.
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The module provides helper functions for the Batch Submit/Describe/Terminate/List service job APIs."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
from typing import Dict, Iterator, List, Optional
18+
from sagemaker.aws_batch.constants import (
19+
SAGEMAKER_TRAINING,
20+
DEFAULT_TIMEOUT,
21+
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG,
22+
)
23+
from sagemaker.aws_batch.boto_client import get_batch_boto_client
24+
25+
26+
def submit_service_job(
    training_payload: Dict,
    job_name: str,
    job_queue: str,
    retry_config: Optional[Dict] = None,
    scheduling_priority: Optional[int] = None,
    timeout: Optional[Dict] = None,
    share_identifier: Optional[str] = None,
    tags: Optional[Dict] = None,
) -> Dict:
    """Batch submit_service_job API helper function.

    The training payload is serialized to JSON and sent as ``serviceRequestPayload``.
    A top-level ``"Tags"`` entry of the training payload is left out of the serialized
    payload and merged into the Batch ``tags`` instead. The caller's
    ``training_payload`` dict is not modified.

    Args:
        training_payload: a dict containing a dict of arguments for Training job.
        job_name: Batch job name.
        job_queue: Batch job queue ARN.
        retry_config: Batch job retry configuration. Falls back to
            ``DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG`` when not specified or empty.
        scheduling_priority: An integer representing scheduling priority. Only sent
            when truthy, so ``None`` and ``0`` are both omitted from the request.
        timeout: Set with value of timeout if specified, else ``DEFAULT_TIMEOUT``.
        share_identifier: value of shareIdentifier if specified.
        tags: A dict of string to string representing Batch tags. Training payload
            tags take precedence in the case of key conflicts.

    Returns:
        The response of the Batch ``submit_service_job`` call
        (a dict containing jobArn, jobName and jobId).
    """
    if timeout is None:
        timeout = DEFAULT_TIMEOUT

    # Strip "Tags" from a shallow copy so the caller's dict stays intact and can be
    # reused (e.g. resubmitted) with its tags still present.
    request_payload = dict(training_payload)
    training_payload_tags = request_payload.pop("Tags", None)

    payload = {
        "jobName": job_name,
        "jobQueue": job_queue,
        "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG,
        "serviceJobType": SAGEMAKER_TRAINING,
        "serviceRequestPayload": json.dumps(request_payload),
        "timeoutConfig": timeout,
    }
    if retry_config:
        payload["retryStrategy"] = retry_config
    if scheduling_priority:
        payload["schedulingPriority"] = scheduling_priority
    if share_identifier:
        payload["shareIdentifier"] = share_identifier
    if tags or training_payload_tags:
        payload["tags"] = __merge_tags(tags, training_payload_tags)

    client = get_batch_boto_client()
    return client.submit_service_job(**payload)
72+
73+
74+
def describe_service_job(job_id: str) -> Dict:
    """Call the Batch describe_service_job API for a single job.

    Args:
        job_id: ID of the job to describe; sent as ``jobId``.

    Returns:
        The response dict of the Batch call. Sample shape::

            {
                'attempts': [
                    {
                        'serviceResourceId': {
                            'name': 'string',
                            'value': 'string'
                        },
                        'startedAt': 123,
                        'stoppedAt': 123,
                        'statusReason': 'string'
                    },
                ],
                'createdAt': 123,
                'isTerminated': True|False,
                'jobArn': 'string',
                'jobId': 'string',
                'jobName': 'string',
                'jobQueue': 'string',
                'retryStrategy': {
                    'attempts': 123
                },
                'schedulingPriority': 123,
                'serviceRequestPayload': 'string',
                'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING',
                'shareIdentifier': 'string',
                'startedAt': 123,
                'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
                'statusReason': 'string',
                'stoppedAt': 123,
                'tags': {
                    'string': 'string'
                },
                'timeout': {
                    'attemptDurationSeconds': 123
                }
            }
    """
    return get_batch_boto_client().describe_service_job(jobId=job_id)
120+
121+
122+
def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict:
    """Batch terminate_service_job API helper function.

    Args:
        job_id: Job ID
        reason: A string representing terminate reason. An explicit ``None`` is
            replaced by the default reason text, because the value is always sent
            with the request.

    Returns:
        The response of the Batch ``terminate_service_job`` call.
    """
    # The signature advertises Optional[str]; never forward None to the client,
    # which only accepts a string for this parameter.
    if reason is None:
        reason = "default terminate reason"
    client = get_batch_boto_client()
    return client.terminate_service_job(jobId=job_id, reason=reason)
133+
134+
135+
def list_service_job(
    job_queue: str,
    job_status: Optional[str] = None,
    filters: Optional[List] = None,
    next_token: Optional[str] = None,
) -> Iterator[Dict]:
    """Batch list_service_job API helper function.

    This is a generator: each iteration performs one ``list_service_jobs`` call and
    yields its response. Iteration ends once a response carries no ``nextToken``.

    Args:
        job_queue: Batch job queue ARN.
        job_status: Batch job status. Only sent when truthy.
        filters: A list of Dict. Each contains a filter. Only sent when truthy.
        next_token: Used to retrieve data in next page. When given, the first
            request starts from this token.

    Yields:
        The response dict of each ``list_service_jobs`` call, one per page.
    """
    client = get_batch_boto_client()
    payload = {"jobQueue": job_queue}
    if filters:
        payload["filters"] = filters
    if job_status:
        payload["jobStatus"] = job_status

    # Paginate with a loop rather than recursion: nesting one generator per page
    # makes every step walk the whole chain and can exhaust the recursion limit on
    # long listings.
    while True:
        if next_token:
            payload["nextToken"] = next_token
        part_of_jobs = client.list_service_jobs(**payload)
        next_token = part_of_jobs.get("nextToken")
        yield part_of_jobs
        if not next_token:
            return
165+
166+
167+
def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]:
    """Combine Batch tags with the tags carried by the training payload.

    Args:
        batch_tags: A dict of string to string representing Batch tags.
        training_tags: A list of `{"Key": "string", "Value": "string"}` objects
            representing training payload tags.

    Returns:
        A new dict holding the Batch tags overlaid with the training tags; on a key
        conflict the training tag value wins. When ``training_tags`` is None or
        empty, ``batch_tags`` itself is returned unaltered (not a copy).
    """
    if training_tags:
        # Flatten the Key/Value records into a plain mapping first.
        overrides = {}
        for entry in training_tags:
            tag_key = entry["Key"]
            overrides[tag_key] = entry["Value"]

        # Copy so the caller's Batch tags are never modified in place.
        merged = batch_tags.copy() if batch_tags else {}
        merged.update(overrides)
        return merged

    return batch_tags
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The file provides helper function for getting Batch boto client."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional
17+
import boto3
18+
19+
20+
def get_batch_boto_client(
    region: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> boto3.session.Session.client:
    """Create a boto3 client for the "batch" service.

    Args:
        region: Forwarded to boto3 as ``region_name``.
        endpoint: Batch API endpoint, forwarded to boto3 as ``endpoint_url``.

    Returns:
        The client object returned by ``boto3.client``.
    """
    client_kwargs = {"region_name": region, "endpoint_url": endpoint}
    return boto3.client("batch", **client_kwargs)

src/sagemaker/aws_batch/constants.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The file defines constants used for Batch API helper functions."""
14+
15+
from __future__ import absolute_import
16+
17+
# Value sent as "serviceJobType" by the Batch API helpers.
SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING"

# Default per-attempt timeout: one day, expressed in seconds.
DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 24 * 60 * 60
DEFAULT_TIMEOUT = dict(attemptDurationSeconds=DEFAULT_ATTEMPT_DURATION_IN_SECONDS)

# Polling interval in seconds (consumers are outside this module).
POLL_IN_SECONDS = 5

# Job status strings. Note that the "completed" constant holds the value "SUCCEEDED".
JOB_STATUS_RUNNING = "RUNNING"
JOB_STATUS_COMPLETED = "SUCCEEDED"
JOB_STATUS_FAILED = "FAILED"

# Retry strategy used when a caller does not provide one.
# NOTE(review): "attempts" is 1 here; confirm whether the RETRY rule below can
# ever trigger a re-run with a single attempt.
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = {
    "attempts": 1,
    "evaluateOnExit": [
        {
            "action": "RETRY",
            "onStatusReason": (
                "Received status from SageMaker:InternalServerError: "
                "We encountered an internal error. Please try again."
            ),
        },
        # Any other status reason exits.
        {"action": "EXIT", "onStatusReason": "*"},
    ],
}

src/sagemaker/aws_batch/exception.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The file defines customized exceptions for Batch queueing."""
14+
from __future__ import absolute_import
15+
16+
17+
class NoTrainingJob(Exception):
    """Signal that no Training job has been created by AWS Batch service.

    The object passed at construction is kept on ``self.value`` and is also
    forwarded to ``Exception`` as its single argument.
    """

    def __init__(self, value):
        """Store ``value`` and initialize the base exception with it."""
        self.value = value
        super().__init__(value)

    def __str__(self):
        """Render the exception as the ``repr`` of the stored value.

        Returns: a String containing exception error messages.

        """
        return f"{self.value!r}"
34+
35+
36+
class MissingRequiredArgument(Exception):
    """Signal that some required arguments are missing.

    The object passed at construction is kept on ``self.value`` and is also
    forwarded to ``Exception`` as its single argument.
    """

    def __init__(self, value):
        """Store ``value`` and initialize the base exception with it."""
        self.value = value
        super().__init__(value)

    def __str__(self):
        """Render the exception as the ``repr`` of the stored value.

        Returns: a String containing exception error messages.

        """
        return f"{self.value!r}"

0 commit comments

Comments
 (0)