| 
 | 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  | 
 | 2 | +#  | 
 | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You  | 
 | 4 | +# may not use this file except in compliance with the License. A copy of  | 
 | 5 | +# the License is located at  | 
 | 6 | +#  | 
 | 7 | +#     http://aws.amazon.com/apache2.0/  | 
 | 8 | +#  | 
 | 9 | +# or in the "license" file accompanying this file. This file is  | 
 | 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF  | 
 | 11 | +# ANY KIND, either express or implied. See the License for the specific  | 
 | 12 | +# language governing permissions and limitations under the License.  | 
"""The module provides helper functions for Batch Submit/Describe/Terminate/List job APIs."""
from __future__ import absolute_import

import json
from typing import Dict, Iterator, List, Optional

from sagemaker.aws_batch.constants import (
    SAGEMAKER_TRAINING,
    DEFAULT_TIMEOUT,
    DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG,
)
from sagemaker.aws_batch.boto_client import get_batch_boto_client
 | 24 | + | 
 | 25 | + | 
 | 26 | +def submit_service_job(  | 
 | 27 | +    training_payload: Dict,  | 
 | 28 | +    job_name: str,  | 
 | 29 | +    job_queue: str,  | 
 | 30 | +    retry_config: Optional[Dict] = None,  | 
 | 31 | +    scheduling_priority: Optional[int] = None,  | 
 | 32 | +    timeout: Optional[Dict] = None,  | 
 | 33 | +    share_identifier: Optional[str] = None,  | 
 | 34 | +    tags: Optional[Dict] = None,  | 
 | 35 | +) -> Dict:  | 
 | 36 | +    """Batch submit_service_job API helper function.  | 
 | 37 | +
  | 
 | 38 | +    Args:  | 
 | 39 | +        training_payload: a dict containing a dict of arguments for Training job.  | 
 | 40 | +        job_name: Batch job name.  | 
 | 41 | +        job_queue: Batch job queue ARN.  | 
 | 42 | +        retry_config: Batch job retry configuration.  | 
 | 43 | +        scheduling_priority: An integer representing scheduling priority.  | 
 | 44 | +        timeout: Set with value of timeout if specified, else default to 1 day.  | 
 | 45 | +        share_identifier: value of shareIdentifier if specified.  | 
 | 46 | +        tags: A dict of string to string representing Batch tags.  | 
 | 47 | +
  | 
 | 48 | +    Returns:  | 
 | 49 | +        A dict containing jobArn, jobName and jobId.  | 
 | 50 | +    """  | 
 | 51 | +    if timeout is None:  | 
 | 52 | +        timeout = DEFAULT_TIMEOUT  | 
 | 53 | +    client = get_batch_boto_client()  | 
 | 54 | +    training_payload_tags = training_payload.pop("Tags", None)  | 
 | 55 | +    payload = {  | 
 | 56 | +        "jobName": job_name,  | 
 | 57 | +        "jobQueue": job_queue,  | 
 | 58 | +        "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG,  | 
 | 59 | +        "serviceJobType": SAGEMAKER_TRAINING,  | 
 | 60 | +        "serviceRequestPayload": json.dumps(training_payload),  | 
 | 61 | +        "timeoutConfig": timeout,  | 
 | 62 | +    }  | 
 | 63 | +    if retry_config:  | 
 | 64 | +        payload["retryStrategy"] = retry_config  | 
 | 65 | +    if scheduling_priority:  | 
 | 66 | +        payload["schedulingPriority"] = scheduling_priority  | 
 | 67 | +    if share_identifier:  | 
 | 68 | +        payload["shareIdentifier"] = share_identifier  | 
 | 69 | +    if tags or training_payload_tags:  | 
 | 70 | +        payload["tags"] = __merge_tags(tags, training_payload_tags)  | 
 | 71 | +    return client.submit_service_job(**payload)  | 
 | 72 | + | 
 | 73 | + | 
 | 74 | +def describe_service_job(job_id: str) -> Dict:  | 
 | 75 | +    """Batch describe_service_job API helper function.  | 
 | 76 | +
  | 
 | 77 | +    Args:  | 
 | 78 | +        job_id: Job ID used.  | 
 | 79 | +
  | 
 | 80 | +    Returns: a dict. See the sample below  | 
 | 81 | +    {  | 
 | 82 | +        'attempts': [  | 
 | 83 | +            {  | 
 | 84 | +                'serviceResourceId': {  | 
 | 85 | +                    'name': 'string',  | 
 | 86 | +                    'value': 'string'  | 
 | 87 | +                },  | 
 | 88 | +                'startedAt': 123,  | 
 | 89 | +                'stoppedAt': 123,  | 
 | 90 | +                'statusReason': 'string'  | 
 | 91 | +            },  | 
 | 92 | +        ],  | 
 | 93 | +        'createdAt': 123,  | 
 | 94 | +        'isTerminated': True|False,  | 
 | 95 | +        'jobArn': 'string',  | 
 | 96 | +        'jobId': 'string',  | 
 | 97 | +        'jobName': 'string',  | 
 | 98 | +        'jobQueue': 'string',  | 
 | 99 | +        'retryStrategy': {  | 
 | 100 | +            'attempts': 123  | 
 | 101 | +        },  | 
 | 102 | +        'schedulingPriority': 123,  | 
 | 103 | +        'serviceRequestPayload': 'string',  | 
 | 104 | +        'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING',  | 
 | 105 | +        'shareIdentifier': 'string',  | 
 | 106 | +        'startedAt': 123,  | 
 | 107 | +        'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',  | 
 | 108 | +        'statusReason': 'string',  | 
 | 109 | +        'stoppedAt': 123,  | 
 | 110 | +        'tags': {  | 
 | 111 | +            'string': 'string'  | 
 | 112 | +        },  | 
 | 113 | +        'timeout': {  | 
 | 114 | +            'attemptDurationSeconds': 123  | 
 | 115 | +        }  | 
 | 116 | +    }  | 
 | 117 | +    """  | 
 | 118 | +    client = get_batch_boto_client()  | 
 | 119 | +    return client.describe_service_job(jobId=job_id)  | 
 | 120 | + | 
 | 121 | + | 
 | 122 | +def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict:  | 
 | 123 | +    """Batch terminate_service_job API helper function.  | 
 | 124 | +
  | 
 | 125 | +    Args:  | 
 | 126 | +        job_id: Job ID  | 
 | 127 | +        reason: A string representing terminate reason.  | 
 | 128 | +
  | 
 | 129 | +    Returns: an empty dict  | 
 | 130 | +    """  | 
 | 131 | +    client = get_batch_boto_client()  | 
 | 132 | +    return client.terminate_service_job(jobId=job_id, reason=reason)  | 
 | 133 | + | 
 | 134 | + | 
 | 135 | +def list_service_job(  | 
 | 136 | +    job_queue: str,  | 
 | 137 | +    job_status: Optional[str] = None,  | 
 | 138 | +    filters: Optional[List] = None,  | 
 | 139 | +    next_token: Optional[str] = None,  | 
 | 140 | +) -> Dict:  | 
 | 141 | +    """Batch list_service_job API helper function.  | 
 | 142 | +
  | 
 | 143 | +    Args:  | 
 | 144 | +        job_queue: Batch job queue ARN.  | 
 | 145 | +        job_status: Batch job status.  | 
 | 146 | +        filters: A list of Dict. Each contains a filter.  | 
 | 147 | +        next_token: Used to retrieve data in next page.  | 
 | 148 | +
  | 
 | 149 | +    Returns: A generator containing list results.  | 
 | 150 | +
  | 
 | 151 | +    """  | 
 | 152 | +    client = get_batch_boto_client()  | 
 | 153 | +    payload = {"jobQueue": job_queue}  | 
 | 154 | +    if filters:  | 
 | 155 | +        payload["filters"] = filters  | 
 | 156 | +    if next_token:  | 
 | 157 | +        payload["nextToken"] = next_token  | 
 | 158 | +    if job_status:  | 
 | 159 | +        payload["jobStatus"] = job_status  | 
 | 160 | +    part_of_jobs = client.list_service_jobs(**payload)  | 
 | 161 | +    next_token = part_of_jobs.get("nextToken")  | 
 | 162 | +    yield part_of_jobs  | 
 | 163 | +    if next_token:  | 
 | 164 | +        yield from list_service_job(job_queue, job_status, filters, next_token)  | 
 | 165 | + | 
 | 166 | + | 
 | 167 | +def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]:  | 
 | 168 | +    """Merges Batch and training payload tags.  | 
 | 169 | +
  | 
 | 170 | +    Returns a copy of Batch tags merged with training payload tags.  Training payload tags take  | 
 | 171 | +    precedence in the case of key conflicts.  | 
 | 172 | +
  | 
 | 173 | +    :param batch_tags: A dict of string to string representing Batch tags.  | 
 | 174 | +    :param training_tags: A list of `{"Key": "string", "Value": "string"}` objects representing  | 
 | 175 | +    training payload tags.  | 
 | 176 | +    :return: A dict of string to string representing batch tags merged with training tags.  | 
 | 177 | +    batch_tags is returned unaltered if training_tags is None or empty.  | 
 | 178 | +    """  | 
 | 179 | +    if not training_tags:  | 
 | 180 | +        return batch_tags  | 
 | 181 | + | 
 | 182 | +    training_tags_to_merge = {tag["Key"]: tag["Value"] for tag in training_tags}  | 
 | 183 | +    batch_tags_copy = batch_tags.copy() if batch_tags else {}  | 
 | 184 | +    batch_tags_copy.update(training_tags_to_merge)  | 
 | 185 | + | 
 | 186 | +    return batch_tags_copy  | 
0 commit comments