diff --git a/src/sagemaker/aws_batch/__init__.py b/src/sagemaker/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/aws_batch/batch_api_helper.py b/src/sagemaker/aws_batch/batch_api_helper.py new file mode 100644 index 0000000000..4482a644ab --- /dev/null +++ b/src/sagemaker/aws_batch/batch_api_helper.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The module provides helper function for Batch Submit/Describe/Terminal job APIs.""" +from __future__ import absolute_import + +import json +from typing import List, Dict, Optional +from sagemaker.aws_batch.constants import ( + SAGEMAKER_TRAINING, + DEFAULT_TIMEOUT, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, +) +from sagemaker.aws_batch.boto_client import get_batch_boto_client + + +def submit_service_job( + training_payload: Dict, + job_name: str, + job_queue: str, + retry_config: Optional[Dict] = None, + scheduling_priority: Optional[int] = None, + timeout: Optional[Dict] = None, + share_identifier: Optional[str] = None, + tags: Optional[Dict] = None, +) -> Dict: + """Batch submit_service_job API helper function. + + Args: + training_payload: a dict containing a dict of arguments for Training job. + job_name: Batch job name. + job_queue: Batch job queue ARN. + retry_config: Batch job retry configuration. + scheduling_priority: An integer representing scheduling priority. + timeout: Set with value of timeout if specified, else default to 1 day. + share_identifier: value of shareIdentifier if specified. + tags: A dict of string to string representing Batch tags. + + Returns: + A dict containing jobArn, jobName and jobId. + """ + if timeout is None: + timeout = DEFAULT_TIMEOUT + client = get_batch_boto_client() + training_payload_tags = training_payload.pop("Tags", None) + payload = { + "jobName": job_name, + "jobQueue": job_queue, + "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + "serviceJobType": SAGEMAKER_TRAINING, + "serviceRequestPayload": json.dumps(training_payload), + "timeoutConfig": timeout, + } + if retry_config: + payload["retryStrategy"] = retry_config + if scheduling_priority: + payload["schedulingPriority"] = scheduling_priority + if share_identifier: + payload["shareIdentifier"] = share_identifier + if tags or training_payload_tags: + payload["tags"] = __merge_tags(tags, training_payload_tags) + return client.submit_service_job(**payload) + + +def describe_service_job(job_id: str) -> Dict: + """Batch describe_service_job API helper function. + + Args: + job_id: Job ID used. + + Returns: a dict. 
See the sample below + { + 'attempts': [ + { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + }, + 'startedAt': 123, + 'stoppedAt': 123, + 'statusReason': 'string' + }, + ], + 'createdAt': 123, + 'isTerminated': True|False, + 'jobArn': 'string', + 'jobId': 'string', + 'jobName': 'string', + 'jobQueue': 'string', + 'retryStrategy': { + 'attempts': 123 + }, + 'schedulingPriority': 123, + 'serviceRequestPayload': 'string', + 'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING', + 'shareIdentifier': 'string', + 'startedAt': 123, + 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', + 'statusReason': 'string', + 'stoppedAt': 123, + 'tags': { + 'string': 'string' + }, + 'timeout': { + 'attemptDurationSeconds': 123 + } + } + """ + client = get_batch_boto_client() + return client.describe_service_job(jobId=job_id) + + +def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict: + """Batch terminate_service_job API helper function. + + Args: + job_id: Job ID + reason: A string representing terminate reason. + + Returns: an empty dict + """ + client = get_batch_boto_client() + return client.terminate_service_job(jobId=job_id, reason=reason) + + +def list_service_job( + job_queue: str, + job_status: Optional[str] = None, + filters: Optional[List] = None, + next_token: Optional[str] = None, +) -> Dict: + """Batch list_service_job API helper function. + + Args: + job_queue: Batch job queue ARN. + job_status: Batch job status. + filters: A list of Dict. Each contains a filter. + next_token: Used to retrieve data in next page. + + Returns: A generator containing list results. + + """ + client = get_batch_boto_client() + payload = {"jobQueue": job_queue} + if filters: + payload["filters"] = filters + if next_token: + payload["nextToken"] = next_token + if job_status: + payload["jobStatus"] = job_status + part_of_jobs = client.list_service_jobs(**payload) + next_token = part_of_jobs.get("nextToken") + yield part_of_jobs + if next_token: + yield from list_service_job(job_queue, job_status, filters, next_token) + + +def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]: + """Merges Batch and training payload tags. + + Returns a copy of Batch tags merged with training payload tags. Training payload tags take + precedence in the case of key conflicts. + + :param batch_tags: A dict of string to string representing Batch tags. + :param training_tags: A list of `{"Key": "string", "Value": "string"}` objects representing + training payload tags. + :return: A dict of string to string representing batch tags merged with training tags. + batch_tags is returned unaltered if training_tags is None or empty. + """ + if not training_tags: + return batch_tags + + training_tags_to_merge = {tag["Key"]: tag["Value"] for tag in training_tags} + batch_tags_copy = batch_tags.copy() if batch_tags else {} + batch_tags_copy.update(training_tags_to_merge) + + return batch_tags_copy diff --git a/src/sagemaker/aws_batch/boto_client.py b/src/sagemaker/aws_batch/boto_client.py new file mode 100644 index 0000000000..87f3486887 --- /dev/null +++ b/src/sagemaker/aws_batch/boto_client.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file provides helper function for getting Batch boto client.""" +from __future__ import absolute_import + +from typing import Optional +import boto3 + + +def get_batch_boto_client( + region: Optional[str] = None, + endpoint: Optional[str] = None, +) -> boto3.session.Session.client: + """Helper function for getting Batch boto3 client. + + Args: + region: Region specified + endpoint: Batch API endpoint. + + Returns: Batch boto3 client. + + """ + return boto3.client("batch", region_name=region, endpoint_url=endpoint) diff --git a/src/sagemaker/aws_batch/constants.py b/src/sagemaker/aws_batch/constants.py new file mode 100644 index 0000000000..ee41d3a413 --- /dev/null +++ b/src/sagemaker/aws_batch/constants.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file defines constants used for Batch API helper functions.""" + +from __future__ import absolute_import + +SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING" +DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 86400 # 1 day in seconds. +DEFAULT_TIMEOUT = {"attemptDurationSeconds": DEFAULT_ATTEMPT_DURATION_IN_SECONDS} +POLL_IN_SECONDS = 5 +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_COMPLETED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError: " + "We encountered an internal error. Please try again.", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} diff --git a/src/sagemaker/aws_batch/exception.py b/src/sagemaker/aws_batch/exception.py new file mode 100644 index 0000000000..94318bbce4 --- /dev/null +++ b/src/sagemaker/aws_batch/exception.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""The file Defines customized exception for Batch queueing""" +from __future__ import absolute_import + + +class NoTrainingJob(Exception): + """Define NoTrainingJob Exception. + + It means no Training job has been created by AWS Batch service. 
+ """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. + + """ + return repr(self.value) + + +class MissingRequiredArgument(Exception): + """Define MissingRequiredArgument exception. + + It means some required arguments are missing. + """ + + def __init__(self, value): + super().__init__(value) + self.value = value + + def __str__(self): + """Convert Exception to string. + + Returns: a String containing exception error messages. + + """ + return repr(self.value) diff --git a/src/sagemaker/aws_batch/training_queue.py b/src/sagemaker/aws_batch/training_queue.py new file mode 100644 index 0000000000..b540fad0a9 --- /dev/null +++ b/src/sagemaker/aws_batch/training_queue.py @@ -0,0 +1,212 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define Queue class for AWS Batch service""" +from __future__ import absolute_import + +from typing import Dict, Optional, List, Union +import logging +from sagemaker.estimator import EstimatorBase, _TrainingJob +from sagemaker.modules.train.model_trainer import ModelTrainer, Mode +from .training_queued_job import TrainingQueuedJob +from .batch_api_helper import submit_service_job, list_service_job +from .exception import MissingRequiredArgument +from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING + + +class TrainingQueue: + """TrainingQueue class for AWS Batch service + + With this class, customers are able to create a new queue and submit jobs to AWS Batch Service. + """ + + def __init__(self, queue_name: str): + self.queue_name = queue_name + + def submit( + self, + training_job: Union[EstimatorBase, ModelTrainer], + inputs, + job_name: Optional[str] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> TrainingQueuedJob: + """Submit a queued job and return a QueuedJob object. + + Args: + training_job: Training job EstimatorBase or ModelTrainer object. + inputs: Training job inputs. + job_name: Batch job name. + retry_config: Retry configuration for Batch job. + priority: Scheduling priority for Batch job. + share_identifier: Share identifier for Batch job. + timeout: Timeout configuration for Batch job. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a TrainingQueuedJob object with Batch job ARN and job name. 
+ + """ + if not isinstance(training_job, (EstimatorBase, ModelTrainer)): + raise TypeError( + "training_job must be an instance of EstimatorBase or ModelTrainer, " + f"but got {type(training_job)}" + ) + + training_payload = {} + if isinstance(training_job, EstimatorBase): + if experiment_config is None: + experiment_config = {} + training_job.prepare_workflow_for_training(job_name) + training_args = _TrainingJob.get_train_args(training_job, inputs, experiment_config) + training_payload = training_job.sagemaker_session.get_train_request(**training_args) + else: + if training_job.training_mode != Mode.SAGEMAKER_TRAINING_JOB: + raise ValueError( + "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB" + ) + if experiment_config is not None: + logging.warning( + "ExperimentConfig is not supported for ModelTrainer. " + "It will be ignored when submitting the job." + ) + training_payload = training_job._create_training_job_args( + input_data_config=inputs, boto3=True + ) + + if timeout is None: + timeout = DEFAULT_TIMEOUT + if job_name is None: + job_name = training_payload["TrainingJobName"] + + resp = submit_service_job( + training_payload, + job_name, + self.queue_name, + retry_config, + priority, + timeout, + share_identifier, + tags, + ) + if "jobArn" not in resp or "jobName" not in resp: + raise MissingRequiredArgument( + "jobArn or jobName is missing in response from Batch submit_service_job API" + ) + return TrainingQueuedJob(resp["jobArn"], resp["jobName"]) + + def map( + self, + training_job: Union[EstimatorBase, ModelTrainer], + inputs, + job_names: Optional[List[str]] = None, + retry_config: Optional[Dict] = None, + priority: Optional[int] = None, + share_identifier: Optional[str] = None, + timeout: Optional[Dict] = None, + tags: Optional[Dict] = None, + experiment_config: Optional[Dict] = None, + ) -> List[TrainingQueuedJob]: + """Submit queued jobs to the provided estimator and return a list of TrainingQueuedJob objects. + + Args: + training_job: Training job EstimatorBase or ModelTrainer object. + inputs: List of Training job inputs. + job_names: List of Batch job names. + retry_config: Retry config for the Batch jobs. + priority: Scheduling priority for the Batch jobs. + share_identifier: Share identifier for the Batch jobs. + timeout: Timeout configuration for the Batch jobs. + tags: Tags apply to Batch job. These tags are for Batch job only. + experiment_config: Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + + Returns: a list of TrainingQueuedJob objects with each Batch job ARN and job name. + + """ + if experiment_config is None: + experiment_config = {} + + if job_names is not None: + if len(job_names) != len(inputs): + raise ValueError( + "When specified, the number of job names must match the number of inputs" + ) + else: + job_names = [None] * len(inputs) + + queued_batch_job_list = [] + for index, value in enumerate(inputs): + queued_batch_job = self.submit( + training_job, + value, + job_names[index], + retry_config, + priority, + share_identifier, + timeout, + tags, + experiment_config, + ) + queued_batch_job_list.append(queued_batch_job) + + return queued_batch_job_list + + def list_jobs( + self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING + ) -> List[TrainingQueuedJob]: + """List Batch jobs according to job_name or status. + + Args: + job_name: Batch job name. + status: Batch job status. 
+ + Returns: A list of QueuedJob. + + """ + filters = None + if job_name: + filters = [{"name": "JOB_NAME", "values": [job_name]}] + status = None # job_status is ignored when job_name is specified. + jobs_to_return = [] + next_token = None + for job_result_dict in list_service_job(self.queue_name, status, filters, next_token): + for job_result in job_result_dict.get("jobSummaryList", []): + if "jobArn" in job_result and "jobName" in job_result: + jobs_to_return.append( + TrainingQueuedJob(job_result["jobArn"], job_result["jobName"]) + ) + else: + logging.warning("Missing JobArn or JobName in Batch ListJobs API") + continue + return jobs_to_return + + def get_job(self, job_name): + """Get a Batch job according to job_name. + + Args: + job_name: Batch job name. + + Returns: The QueuedJob with name matching job_name. + + """ + jobs_to_return = self.list_jobs(job_name) + if len(jobs_to_return) == 0: + raise ValueError(f"Cannot find job: {job_name}") + return jobs_to_return[0] diff --git a/src/sagemaker/aws_batch/training_queued_job.py b/src/sagemaker/aws_batch/training_queued_job.py new file mode 100644 index 0000000000..6bb42c3c61 --- /dev/null +++ b/src/sagemaker/aws_batch/training_queued_job.py @@ -0,0 +1,217 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define QueuedJob class for AWS Batch service""" +from __future__ import absolute_import + +import logging +import time +import asyncio +from typing import Optional, Dict +import nest_asyncio +from sagemaker.estimator import Estimator +from .batch_api_helper import terminate_service_job, describe_service_job +from .exception import NoTrainingJob, MissingRequiredArgument +from ..utils import get_training_job_name_from_training_job_arn +from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS + +logging.basicConfig( + format="%(asctime)s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" +) + + +class TrainingQueuedJob: + """TrainingQueuedJob class for AWS Batch service. + + With this class, customers are able to attach the latest training job to an estimator. + """ + + def __init__(self, job_arn: str, job_name: str): + self.job_arn = job_arn + self.job_name = job_name + self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"} + + def get_estimator(self) -> Estimator: + """Attach the latest training job to an estimator and return. + + Returns: an Estimator instance. 
+ + """ + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + if self._training_job_created(job_status): + if "latestAttempt" not in describe_resp: + raise MissingRequiredArgument("No LatestAttempt in describe call") + new_training_job_name = _get_new_training_job_name_from_latest_attempt( + describe_resp["latestAttempt"] + ) + output_estimator = _construct_estimator_from_training_job_name(new_training_job_name) + _remove_system_tags_in_place_in_estimator_object(output_estimator) + return output_estimator + + _output_attempt_history(describe_resp) + raise NoTrainingJob("No Training job created. Job is still waiting in queue") + + def terminate(self, reason: Optional[str] = "Default terminate reason") -> None: + """Terminate Batch job. + + Args: + reason: Reason for terminating a job. + + Returns: None + + """ + terminate_service_job(self.job_arn, reason) + + def describe(self) -> Dict: + """Describe Batch job. + + Returns: A dict which includes job parameters, job status, attempts and so on. + + """ + return describe_service_job(self.job_arn) + + def _training_job_created(self, status: str) -> bool: + """Return True if a Training job has been created + + Args: + status: Job status returned from Batch API. + + Returns: a boolean indicating whether a Training job has been created. + + """ + return status not in self._no_training_job_status + + def result(self, timeout: int = None) -> Dict: + """Fetch the terminal result of the Batch job. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict. + + """ + nest_asyncio.apply() + loop = asyncio.get_event_loop() + task = loop.create_task(self.fetch_job_results(timeout)) + resp = loop.run_until_complete(task) + return resp + + async def fetch_job_results(self, timeout: int = None) -> Dict: + """Async method that waits for the Batch job to complete or until timeout. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict, or an Error. + + """ + self.wait(timeout) + + describe_resp = self.describe() + if describe_resp.get("status", "") == JOB_STATUS_COMPLETED: + return describe_resp + if describe_resp.get("status", "") == JOB_STATUS_FAILED: + raise RuntimeError(describe_resp["statusReason"]) + raise TimeoutError("Reached timeout before the Batch job reached a terminal status") + + def wait(self, timeout: int = None) -> Dict: + """Wait for the Batch job to finish. + + This method blocks on the job completing for up to the timeout value (if specified). + If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: The last describe_service_job response for the Batch job. + """ + request_end_time = time.time() + timeout if timeout else None + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + while not job_completed: + if timeout and time.time() > request_end_time: + logging.info( + "Timeout exceeded: %d seconds elapsed. 
Returning current results", timeout + ) + break + if job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED): + break + + time.sleep(POLL_IN_SECONDS) + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + return describe_resp + + +def _construct_estimator_from_training_job_name(training_job_name: str) -> Estimator: + """Build Estimator instance from payload. + + Args: + training_job_name: Training job name. + + Returns: an Estimator instance. + + """ + return Estimator.attach(training_job_name) + + +def _output_attempt_history(describe_resp: Dict) -> None: + """Print attempt history if no Training job created. + + Args: + describe_resp: Describe response from Batch API. + + Returns: None + + """ + has_seen_status_reason = False + for i, attempt_dict in enumerate(describe_resp.get("attempts", [])): + if "statusReason" in attempt_dict: + logging.info("Attempt %d - %s", i + 1, attempt_dict["statusReason"]) + has_seen_status_reason = True + if not has_seen_status_reason: + logging.info("No attempts found or no statusReason found.") + + +def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str: + """Extract new Training job name from latest attempt in Batch Describe response. + + Args: + latest_attempt: a Dict containing Training job arn. + + Returns: new Training job name or None if not found. + + """ + training_job_arn = latest_attempt.get("serviceResourceId", {}).get("value", None) + return get_training_job_name_from_training_job_arn(training_job_arn) + + +def _remove_system_tags_in_place_in_estimator_object(estimator: Estimator) -> None: + """Remove system tags in place. + + Args: + estimator: input Estimator object. + + Returns: None. Remove system tags in place. + + """ + new_tags = [] + for tag_dict in estimator.tags: + if not tag_dict.get("Key", "").startswith("aws:"): + new_tags.append(tag_dict) + estimator.tags = new_tags diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 9b4beae5c4..0055416327 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2546,6 +2546,11 @@ def start_new(cls, estimator, inputs, experiment_config): return cls(estimator.sagemaker_session, estimator._current_job_name) + @classmethod + def get_train_args(cls, estimator, inputs, experiment_config): + """A public function which is same as _get_train_args function.""" + return cls._get_train_args(estimator, inputs, experiment_config) + @classmethod def _get_train_args(cls, estimator, inputs, experiment_config): """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator. 
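Reviewer note: the new public `_TrainingJob.get_train_args` hook above is what `TrainingQueue.submit` uses (together with `Session.get_train_request`, further down) to build the CreateTrainingJob payload that gets handed to AWS Batch as a service job. A minimal usage sketch, assuming an existing Batch job queue; the queue name, image URI, role, and S3 paths below are placeholders, not values from this change:

```python
from sagemaker.estimator import Estimator
from sagemaker.aws_batch.training_queue import TrainingQueue

# Hypothetical estimator configuration for illustration only.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://<bucket>/output",
)

queue = TrainingQueue("my-training-queue")  # assumed queue name
queued_job = queue.submit(estimator, inputs="s3://<bucket>/train")

# Block until the Batch job reaches a terminal status (or the timeout elapses).
result = queued_job.result(timeout=3600)
print(result["status"])

# Once Batch has started the underlying SageMaker training job, re-attach to it.
attached_estimator = queued_job.get_estimator()
```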
diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 24b7922895..828c5da198 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -27,6 +27,7 @@ from sagemaker_core.resources import TrainingJob from sagemaker_core import shapes from sagemaker_core.shapes import AlgorithmSpecification +from sagemaker_core.main.utils import serialize from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call @@ -252,6 +253,7 @@ class ModelTrainer(BaseModel): _is_nova_recipe: Optional[bool] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ "role", @@ -380,6 +382,8 @@ def __del__(self): if hasattr(self, "__pydantic_fields_set__"): if self._temp_recipe_train_dir is not None: self._temp_recipe_train_dir.cleanup() + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -590,28 +594,25 @@ def _fetch_bucket_name_and_prefix(session: Session) -> str: return f"{session.default_bucket()}/{session.default_bucket_prefix}" return session.default_bucket() - @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") - @validate_call - def train( + def _create_training_job_args( self, input_data_config: Optional[List[Union[Channel, InputData]]] = None, - wait: Optional[bool] = True, - logs: Optional[bool] = True, - ): - """Train a model using AWS SageMaker. + boto3: bool = False, + ) -> Dict[str, Any]: + """Create the training job arguments. Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel objects or a dictionary of channel names to DataSourceType. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object. - wait (Optional[bool]): - Whether to wait for the training job to complete before returning. - Defaults to True. - logs (Optional[bool]): - Whether to display the training container logs while training. - Defaults to True. + boto3 (bool): Whether to return the arguments in boto3 format. Defaults to False. + By default, the arguments are returned in the format used by the SageMaker Core. + + Returns: + Dict[str, Any]: The training job arguments. 
""" self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) @@ -672,16 +673,18 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) + self._temp_code_dir = TemporaryDirectory( + prefix=os.path.join(self.local_container_root + "/") + ) else: - tmp_dir = TemporaryDirectory() + self._temp_code_dir = TemporaryDirectory() # Copy everything under container_drivers/ to a temporary directory - shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True) # If distributed is provided, overwrite code under /drivers if self.distributed: distributed_driver_dir = self.distributed.driver_dir - driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") + driver_dir = os.path.join(self._temp_code_dir.name, "distributed_drivers") shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code @@ -696,7 +699,7 @@ def train( final_input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=tmp_dir, + tmp_dir=self._temp_code_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -705,13 +708,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=self._temp_code_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=self._temp_code_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=tmp_dir.name, + data_source=self._temp_code_dir.name, key_prefix=input_data_key_prefix, ignore_patterns=self.source_code.ignore_patterns, ) @@ -738,40 +741,93 @@ def train( resource_config = self.compute._to_resource_config() vpc_config = self.networking._to_vpc_config() if self.networking else None - if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: - training_job = TrainingJob.create( - training_job_name=current_training_job_name, - algorithm_specification=algorithm_specification, - hyper_parameters=string_hyper_parameters, - input_data_config=final_input_data_config, - resource_config=resource_config, - vpc_config=vpc_config, - # Public Instance Attributes - session=self.sagemaker_session.boto_session, - role_arn=self.role, - tags=self.tags, - stopping_condition=self.stopping_condition, - output_data_config=self.output_data_config, - checkpoint_config=self.checkpoint_config, - environment=self.environment, - enable_managed_spot_training=self.compute.enable_managed_spot_training, - enable_inter_container_traffic_encryption=( - self.networking.enable_inter_container_traffic_encryption - if self.networking - else None - ), - enable_network_isolation=( - self.networking.enable_network_isolation if self.networking else None - ), - # Private Instance Attributes - remote_debug_config=self._remote_debug_config, - tensor_board_output_config=self._tensorboard_output_config, - retry_strategy=self._retry_strategy, - infra_check_config=self._infra_check_config, - session_chaining_config=self._session_chaining_config, + if boto3: + 
args = {} + args["TrainingJobName"] = current_training_job_name + args["AlgorithmSpecification"] = algorithm_specification + args["HyperParameters"] = string_hyper_parameters + args["InputDataConfig"] = final_input_data_config + args["ResourceConfig"] = resource_config + args["VpcConfig"] = vpc_config + args["RoleArn"] = self.role + args["Tags"] = self.tags + args["StoppingCondition"] = self.stopping_condition + args["OutputDataConfig"] = self.output_data_config + args["CheckpointConfig"] = self.checkpoint_config + args["Environment"] = self.environment + args["EnableManagedSotTraining"] = self.compute.enable_managed_spot_training + args["EnableInterContainerTrafficEncryption"] = ( + self.networking.enable_inter_container_traffic_encryption + if self.networking + else None ) - self._latest_training_job = training_job + args["EnableNetworkIsolation"] = ( + self.networking.enable_network_isolation if self.networking else None + ) + args["RemoteDebugConfig"] = self._remote_debug_config + args["TensorBoardOutputConfig"] = self._tensorboard_output_config + args["RetryStrategy"] = self._retry_strategy + args["InfraCheckConfig"] = self._infra_check_config + args["SessionChainingConfig"] = self._session_chaining_config + return serialize(args) + else: + args = {} + args["training_job_name"] = current_training_job_name + args["algorithm_specification"] = algorithm_specification + args["hyper_parameters"] = string_hyper_parameters + args["input_data_config"] = final_input_data_config + args["resource_config"] = resource_config + args["vpc_config"] = vpc_config + args["session"] = self.sagemaker_session.boto_session + args["role_arn"] = self.role + args["tags"] = self.tags + args["stopping_condition"] = self.stopping_condition + args["output_data_config"] = self.output_data_config + args["checkpoint_config"] = self.checkpoint_config + args["environment"] = self.environment + args["enable_managed_spot_training"] = self.compute.enable_managed_spot_training + args["enable_inter_container_traffic_encryption"] = ( + self.networking.enable_inter_container_traffic_encryption + if self.networking + else None + ) + args["enable_network_isolation"] = ( + self.networking.enable_network_isolation if self.networking else None + ) + args["remote_debug_config"] = self._remote_debug_config + args["tensor_board_output_config"] = self._tensorboard_output_config + args["retry_strategy"] = self._retry_strategy + args["infra_check_config"] = self._infra_check_config + args["session_chaining_config"] = self._session_chaining_config + return args + @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train") + @validate_call + def train( + self, + input_data_config: Optional[List[Union[Channel, InputData]]] = None, + wait: Optional[bool] = True, + logs: Optional[bool] = True, + ): + """Train a model using AWS SageMaker. + + Args: + input_data_config (Optional[List[Union[Channel, InputData]]]): + The input data config for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. + DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + wait (Optional[bool]): + Whether to wait for the training job to complete before returning. + Defaults to True. + logs (Optional[bool]): + Whether to display the training container logs while training. + Defaults to True. 
+ """ + args = self._create_training_job_args(input_data_config=input_data_config) + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: + training_job = TrainingJob.create(**args) + self._latest_training_job = training_job if wait: training_job.wait(logs=logs) if logs and not wait: @@ -780,19 +836,21 @@ def train( ) else: local_container = _LocalContainer( - training_job_name=_get_unique_name(self.base_job_name), - instance_type=resource_config.instance_type, - instance_count=resource_config.instance_count, - image=algorithm_specification.training_image, + training_job_name=args["training_job_name"], + instance_type=args["resource_config"].instance_type, + instance_count=args["resource_config"].instance_count, + image=args["algorithm_specification"].training_image, container_root=self.local_container_root, sagemaker_session=self.sagemaker_session, - container_entrypoint=algorithm_specification.container_entrypoint, - container_arguments=algorithm_specification.container_arguments, - input_data_config=final_input_data_config, - hyper_parameters=string_hyper_parameters, - environment=self.environment, + container_entrypoint=args["algorithm_specification"].container_entrypoint, + container_arguments=args["algorithm_specification"].container_arguments, + input_data_config=args["input_data_config"], + hyper_parameters=args["hyper_parameters"], + environment=args["environment"], ) local_container.train(wait) + if self._temp_code_dir is not None: + self._temp_code_dir.cleanup() def create_input_data_channel( self, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2ff561d784..705d9892fe 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -782,7 +782,7 @@ def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tag return all_tags - def train( # noqa: C901 + def get_train_request( self, input_mode, input_config, @@ -817,7 +817,7 @@ def train( # noqa: C901 retry_strategy=None, remote_debug_config=None, session_chaining_config=None, - ): + ) -> Dict: """Create an Amazon SageMaker training job. Args: @@ -960,12 +960,7 @@ def train( # noqa: C901 "EnableInfraCheck": True, } Returns: - str: ARN of the training job, if it is created. - - Raises: - - botocore.exceptions.ClientError: If Sagemaker throws an exception while creating - training job. - - ValueError: If both image_uri and algorithm are provided, or if neither is provided. + Dict: a Dict containing CreateTrainingJob request. 
""" tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( @@ -1047,6 +1042,228 @@ def train( # noqa: C901 environment=environment, retry_strategy=retry_strategy, ) + return train_request + + def train( # noqa: C901 + self, + input_mode, + input_config, + role=None, + job_name=None, + output_config=None, + resource_config=None, + vpc_config=None, + hyperparameters=None, + stop_condition=None, + tags=None, + metric_definitions=None, + enable_network_isolation=None, + image_uri=None, + training_image_config=None, + infra_check_config=None, + container_entry_point=None, + container_arguments=None, + algorithm_arn=None, + encrypt_inter_container_traffic=None, + use_spot_instances=False, + checkpoint_s3_uri=None, + checkpoint_local_path=None, + experiment_config=None, + debugger_rule_configs=None, + debugger_hook_config=None, + tensorboard_output_config=None, + enable_sagemaker_metrics=None, + profiler_rule_configs=None, + profiler_config=None, + environment: Optional[Dict[str, str]] = None, + retry_strategy=None, + remote_debug_config=None, + session_chaining_config=None, + ): + """Create an Amazon SageMaker training job. + + Args: + input_mode (str): The input mode that the algorithm supports. Valid modes: + * 'File' - Amazon SageMaker copies the training dataset from the S3 location to + a directory in the Docker container. + * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a + Unix-named pipe. + * 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of + downloading the entire dataset before training begins. + input_config (list): A list of Channel objects. Each channel is a named input source. + Please refer to the format details described: + https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training + jobs and APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. You must grant sufficient permissions to this + role. + job_name (str): Name of the training job being created. + output_config (dict): The S3 URI where you want to store the training results and + optional KMS key ID. + resource_config (dict): Contains values for ResourceConfig: + * instance_count (int): Number of EC2 instances to use for training. + The key in resource_config is 'InstanceCount'. + * instance_type (str): Type of EC2 instance to use for training, for example, + 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. + vpc_config (dict): Contains values for VpcConfig: + * subnets (list[str]): List of subnet ids. + The key in vpc_config is 'Subnets'. + * security_group_ids (list[str]): List of security group ids. + The key in vpc_config is 'SecurityGroupIds'. + hyperparameters (dict): Hyperparameters for model training. The hyperparameters are + made accessible as a dict[str, str] to the training code on SageMaker. For + convenience, this accepts other types for keys and values, but ``str()`` will be + called to convert them before training. + stop_condition (dict): Defines when training shall finish. Contains entries that can + be understood by the service like ``MaxRuntimeInSeconds``. + tags (Optional[Tags]): Tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) + used to evaluate the training jobs. 
Each dictionary contains two keys: 'Name' for + the name of the metric, and 'Regex' for the regular expression used to extract the + metric from the logs. + enable_network_isolation (bool): Whether to request for the training job to run with + network isolation or not. + image_uri (str): Docker image containing training code. + training_image_config(dict): Training image configuration. + Optionally, the dict can contain 'TrainingRepositoryAccessMode' and + 'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig'). + For example, + + .. code:: python + + training_image_config = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": + "arn:aws:lambda:us-west-2:1234567890:function:test" + }, + } + + If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed + through a private Docker registry in customer Vpc. If it's set to Platform or None, + the training image is accessed through ECR. + If TrainingRepositoryCredentialsProviderArn is provided, the credentials to + authenticate to the private Docker registry will be retrieved from this AWS Lambda + function. (default: ``None``). When it's set to None, SageMaker will not do + authentication before pulling the image in the private Docker registry. + container_entry_point (List[str]): Optional. The entrypoint script for a Docker + container used to run a training job. This script takes precedence over + the default train processing instructions. + container_arguments (List[str]): Optional. The arguments for a container used to run + a training job. + algorithm_arn (str): Algorithm Arn from Marketplace. + encrypt_inter_container_traffic (bool): Specifies whether traffic between training + containers is encrypted for the training job (default: ``False``). + use_spot_instances (bool): whether to use spot instances for training. + checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints + that the algorithm persists (if any) during training. (default: + ``None``). + checkpoint_local_path (str): The local path that the algorithm + writes its checkpoints to. SageMaker will persist all files + under this path to `checkpoint_s3_uri` continually during + training. On job startup the reverse happens - data from the + s3 location is downloaded to this path before the algorithm is + started. If the path is unset then SageMaker assumes the + checkpoints will be provided under `/opt/ml/checkpoints/`. + (default: ``None``). + experiment_config (dict[str, str]): Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + The behavior of setting these keys is as follows: + * If `ExperimentName` is supplied but `TrialName` is not a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If `TrialName` is supplied and the Trial already exists the job's Trial Component + will be associated with the Trial. + * If both `ExperimentName` and `TrialName` are not supplied the trial component + will be unassociated. + * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. + enable_sagemaker_metrics (bool): enable SageMaker Metrics Time + Series. For more information see: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html + #SageMaker-Type + -AlgorithmSpecification-EnableSageMakerMetricsTimeSeries + (default: ``None``). 
+ profiler_rule_configs (list[dict]): A list of profiler rule + configurations.src/sagemaker/lineage/artifact.py:285 + profiler_config (dict): Configuration for how profiling information is emitted + with SageMaker Profiler. (default: ``None``). + remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } + session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``) + The dict can contain 'EnableSessionTagChaining'(bool). + For example, + + .. code:: python + + session_chaining_config = { + "EnableSessionTagChaining": True, + } + environment (dict[str, str]) : Environment variables to be set for + use during training job (default: ``None``) + retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. + * max_retry_attsmpts (int): Number of times a job should be retried. + The key in RetryStrategy is 'MaxRetryAttempts'. + infra_check_config(dict): Infra check configuration. + Optionally, the dict can contain 'EnableInfraCheck'(bool). + For example, + + .. code:: python + + infra_check_config = { + "EnableInfraCheck": True, + } + Returns: + str: ARN of the training job, if it is created. + + Raises: + - botocore.exceptions.ClientError: If Sagemaker throws an exception while creating + training job. + - ValueError: If both image_uri and algorithm are provided, or if neither is provided. + """ + train_request = self.get_train_request( + input_mode, + input_config, + role, + job_name, + output_config, + resource_config, + vpc_config, + hyperparameters, + stop_condition, + tags, + metric_definitions, + enable_network_isolation, + image_uri, + training_image_config, + infra_check_config, + container_entry_point, + container_arguments, + algorithm_arn, + encrypt_inter_container_traffic, + use_spot_instances, + checkpoint_s3_uri, + checkpoint_local_path, + experiment_config, + debugger_rule_configs, + debugger_hook_config, + tensorboard_output_config, + enable_sagemaker_metrics, + profiler_rule_configs, + profiler_config, + environment, + retry_strategy, + remote_debug_config, + session_chaining_config, + ) def submit(request): try: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d4faa5ad9f..2a31dfab04 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1502,6 +1502,24 @@ def instance_supports_kms(instance_type: str) -> bool: return volume_size_supported(instance_type) +def get_training_job_name_from_training_job_arn(training_job_arn: str) -> str: + """Extract Training job name from Training job arn. + + Args: + training_job_arn: Training job arn. + + Returns: Training job name. + + """ + if training_job_arn is None: + return None + pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)" + match = re.match(pattern, training_job_arn) + if match: + return match.group(1) + return None + + def get_instance_type_family(instance_type: str) -> str: """Return the family of the instance type. 
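Reviewer note: the `get_training_job_name_from_training_job_arn` helper added to `sagemaker/utils.py` above is what `TrainingQueuedJob.get_estimator` uses to turn the `serviceResourceId` ARN reported by the latest Batch attempt back into a training job name before calling `Estimator.attach`. A quick sketch of the expected behavior (the ARN value is illustrative):

```python
from sagemaker.utils import get_training_job_name_from_training_job_arn

arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-queued-job"
print(get_training_job_name_from_training_job_arn(arn))   # -> "my-queued-job"

# Inputs that are not training-job ARNs (including None) yield None instead of raising.
print(get_training_job_name_from_training_job_arn(None))  # -> None
```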
diff --git a/tests/data/modules/script_mode/custom_script.py b/tests/data/modules/script_mode/custom_script.py index 26e5826267..a57ddee743 100644 --- a/tests/data/modules/script_mode/custom_script.py +++ b/tests/data/modules/script_mode/custom_script.py @@ -76,14 +76,60 @@ def predict_fn(input_data, model): return model(input_data.float()).numpy()[0] +def parse_args(): + """ + Parse the command line arguments + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + type=str, + default=os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")), + help="Directory to save the model", + ) + parser.add_argument( + "--train-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TRAIN", os.path.join(current_dir, "data/train")), + help="Directory containing training data", + ) + parser.add_argument( + "--test-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TEST", os.path.join(current_dir, "data/test")), + help="Directory containing testing data", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Batch size for training", + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="Number of epochs for training", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=0.1, + help="Learning rate for training", + ) + return parser.parse_args() + + def train(): """ Train the PyTorch model """ + args = parse_args() # Directories: train, test and model - train_dir = os.path.join(current_dir, "data/train") - test_dir = os.path.join(current_dir, "data/test") - model_dir = os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")) + train_dir = args.train_dir + test_dir = args.test_dir + model_dir = args.model_dir # Load the training and testing data x_train, y_train = get_train_data(train_dir) @@ -91,9 +137,9 @@ def train(): train_ds = TensorDataset(x_train, y_train) # Training parameters - used to configure the training loop - batch_size = 64 - epochs = 1 - learning_rate = 0.1 + batch_size = args.batch_size + epochs = args.epochs + learning_rate = args.learning_rate logger.info( "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) ) diff --git a/tests/integ/sagemaker/aws_batch/__init__.py b/tests/integ/sagemaker/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/aws_batch/manager.py b/tests/integ/sagemaker/aws_batch/manager.py new file mode 100644 index 0000000000..b417f86b53 --- /dev/null +++ b/tests/integ/sagemaker/aws_batch/manager.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import time + + +class BatchTestResourceManager: + + def __init__( + self, + batch_client, + queue_name="pysdk-test-queue", + service_env_name="pysdk-test-queue-service-environment", + ): + self.batch_client = batch_client + self.queue_name = queue_name + self.service_environment_name = service_env_name + + def _create_or_get_service_environment(self, service_environment_name): + print(f"Creating service environment: {service_environment_name}") + try: + response = self.batch_client.create_service_environment( + serviceEnvironmentName=service_environment_name, + serviceEnvironmentType="SAGEMAKER_TRAINING", + capacityLimits=[{"maxCapacity": 10, "capacityUnit": "NUM_INSTANCES"}], + ) + print(f"Service environment {service_environment_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + return response["serviceEnvironments"][0] + else: + print(f"Error creating service environment: {e}") + raise + + def _create_or_get_queue(self, queue_name, service_environment_arn): + + print(f"Creating job queue: {queue_name}") + try: + response = self.batch_client.create_job_queue( + jobQueueName=queue_name, + priority=1, + computeEnvironmentOrder=[], + serviceEnvironmentOrder=[ + { + "order": 1, + "serviceEnvironment": service_environment_arn, + }, + ], + jobQueueType="SAGEMAKER_TRAINING", + ) + print(f"Job queue {queue_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + return response["jobQueues"][0] + else: + print(f"Error creating job queue: {e}") + raise + + def _update_queue_state(self, queue_name, state): + try: + print(f"Updating queue {queue_name} to state {state}") + response = self.batch_client.update_job_queue(jobQueue=queue_name, state=state) + return response + except Exception as e: + print(f"Error updating queue: {e}") + + def _update_service_environment_state(self, service_environment_name, state): + print(f"Updating service environment {service_environment_name} to state {state}") + try: + response = self.batch_client.update_service_environment( + serviceEnvironment=service_environment_name, state=state + ) + return response + except Exception as e: + print(f"Error updating service environment: {e}") + + def _wait_for_queue_state(self, queue_name, state): + print(f"Waiting for queue {queue_name} to be {state}...") + while True: + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + print(f"Current state: {response}") + if response["jobQueues"][0]["state"] == state: + break + time.sleep(5) + print(f"Queue {queue_name} is now {state}.") + + def _wait_for_service_environment_state(self, service_environment_name, state): + print(f"Waiting for service environment {service_environment_name} to be {state}...") + while True: + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + print(f"Current state: {response}") + if response["serviceEnvironments"][0]["state"] == state: + break + time.sleep(5) + print(f"Service environment {service_environment_name} is now {state}.") + + def get_or_create_resources(self, queue_name=None, 
service_environment_name=None): + queue_name = queue_name or self.queue_name + service_environment_name = service_environment_name or self.service_environment_name + + service_environment = self._create_or_get_service_environment(service_environment_name) + if service_environment.get("state") != "ENABLED": + self._update_service_environment_state(service_environment_name, "ENABLED") + self._wait_for_service_environment_state(service_environment_name, "ENABLED") + time.sleep(10) + + queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"]) + if queue.get("state") != "ENABLED": + self._update_queue_state(queue_name, "ENABLED") + self._wait_for_queue_state(queue_name, "ENABLED") + time.sleep(10) + return queue, service_environment diff --git a/tests/integ/sagemaker/aws_batch/test_queue.py b/tests/integ/sagemaker/aws_batch/test_queue.py new file mode 100644 index 0000000000..20b8de55c1 --- /dev/null +++ b/tests/integ/sagemaker/aws_batch/test_queue.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +import botocore +import pytest + +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules.configs import SourceCode, InputData, Compute + +from sagemaker.aws_batch.training_queue import TrainingQueue + +from tests.integ import DATA_DIR +from tests.integ.sagemaker.modules.conftest import modules_sagemaker_session # noqa: F401 +from tests.integ.sagemaker.modules.train.test_model_trainer import ( + DEFAULT_CPU_IMAGE, +) +from tests.integ.sagemaker.aws_batch.manager import BatchTestResourceManager + + +@pytest.fixture(scope="module") +def batch_client(): + return boto3.client("batch", region_name="us-west-2") + + +@pytest.fixture(scope="function") +def batch_test_resource_manager(batch_client): + resource_manager = BatchTestResourceManager(batch_client=batch_client) + resource_manager.get_or_create_resources() + return resource_manager + + +def test_model_trainer_submit(batch_test_resource_manager, modules_sagemaker_session): # noqa: F811 + queue_name = batch_test_resource_manager.queue_name + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/script_mode/", + requirements="requirements.txt", + entry_script="custom_script.py", + ) + hyperparameters = { + "batch-size": 32, + "epochs": 1, + "learning-rate": 0.01, + } + compute = Compute(instance_type="ml.m5.2xlarge") + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + source_code=source_code, + compute=compute, + hyperparameters=hyperparameters, + base_job_name="test-batch-model-trainer", + ) + train_data = InputData( + channel_name="train", + data_source=f"{DATA_DIR}/modules/script_mode/data/train/", + ) + test_data = InputData( + channel_name="test", + data_source=f"{DATA_DIR}/modules/script_mode/data/test/", + ) + + training_queue = TrainingQueue(queue_name=queue_name) + + try: + queued_job = training_queue.submit( + training_job=model_trainer, + 
inputs=[train_data, test_data], + ) + except botocore.exceptions.ClientError as e: + print(e.response["ResponseMetadata"]) + print(e.response["Error"]["Message"]) + raise e + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUBMITTED" + + queued_job.wait(timeout=1800) + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUCCEEDED" diff --git a/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py b/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor_integ.py similarity index 100% rename from tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py rename to tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor_integ.py diff --git a/tests/integ/sagemaker/modules/conftest.py b/tests/integ/sagemaker/modules/conftest.py index c3de81157a..d6d3877de4 100644 --- a/tests/integ/sagemaker/modules/conftest.py +++ b/tests/integ/sagemaker/modules/conftest.py @@ -29,7 +29,7 @@ def modules_sagemaker_session(): os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION region_manual_set = True else: - region_manual_set = True + region_manual_set = False boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) sagemaker_session = Session(boto_session=boto_session) diff --git a/tests/unit/sagemaker/aws_batch/__init__.py b/tests/unit/sagemaker/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/aws_batch/constants.py b/tests/unit/sagemaker/aws_batch/constants.py new file mode 100644 index 0000000000..8745e3558f --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/constants.py @@ -0,0 +1,72 @@ +from __future__ import absolute_import + + +TRAINING_JOB_NAME = "my-training-job" +TRAINING_IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.8.0-cpu-py3" +TRAINING_INPUT_MODE = "File" +CONTAINER_ENTRYPOINT = ["echo", "hello"] +EXECUTION_ROLE = "myrole" +S3_OUTPUT_PATH = "s3://output" +INSTANCE_TYPE = "ml.m4.xlarge" +INSTANCE_COUNT = 1 +VOLUME_SIZE_IN_GB = 1 +MAX_RUNTIME_IN_SECONDS = 600 +TRAINING_JOB_ARN = "arn:aws:sagemaker:us-west-2:476748761737:training-job/jobName" +JOB_NAME = "jobName" +JOB_NAME_IN_PAYLOAD = "jobNameInPayload" +JOB_ID = "123" +JOB_ARN = "arn:batch:job" +JOB_QUEUE = "testQueue" +JOB_STATUS_RUNNABLE = "RUNNABLE" +JOB_STATUS_RUNNING = "RUNNING" +JOB_STATUS_COMPLETED = "SUCCEEDED" +JOB_STATUS_FAILED = "FAILED" +NEXT_TOKEN = "SomeNextToken" +SCHEDULING_PRIORITY = 1 +ATTEMPT_DURATION_IN_SECONDS = 100 +REASON = "killed by Batch API" +SHARE_IDENTIFIER = "shareId" +BATCH_TAGS = {"batch_k": "batch_v"} +TRAINING_TAGS = [{"Key": "training_k", "Value": "training_v"}] +TRAINING_TAGS_DUPLICATING_BATCH_TAGS = [ + *TRAINING_TAGS, + {"Key": "batch_k", "Value": "this value should win"}, +] +TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS = {"training_k": "training_v"} +MERGED_TAGS = {**BATCH_TAGS, **TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS} +MERGED_TAGS_TRAINING_OVERRIDE = { + **TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS, + "batch_k": "this value should win", +} +EXPERIMENT_CONFIG_EMPTY = {} + +TRAINING_JOB_PAYLOAD_IN_PASCALCASE = {"TrainingJobName": JOB_NAME_IN_PAYLOAD} +TIMEOUT_CONFIG = {"attemptDurationSeconds": ATTEMPT_DURATION_IN_SECONDS} +SUBMIT_SERVICE_JOB_RESP = {"jobArn": JOB_ARN, "jobName": JOB_NAME, "jobId": JOB_ID} +FIRST_LIST_SERVICE_JOB_RESP = { + "jobSummaryList": [{"jobName": JOB_NAME, "jobArn": JOB_ARN}], + "nextToken": NEXT_TOKEN, +} +SECOND_LIST_SERVICE_JOB_RESP = { + 
"jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN}, + {"jobName": JOB_NAME, "jobArn": JOB_ARN}, + ], + "nextToken": NEXT_TOKEN, +} +INCORRECT_FIRST_LIST_SERVICE_JOB_RESP = { + "jobSummaryList": [{"jobName": JOB_NAME}], + "nextToken": NEXT_TOKEN, +} +EMPTY_LIST_SERVICE_JOB_RESP = {"jobSummaryList": [], "nextToken": None} +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError: " + "We encountered an internal error. Please try again.", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} diff --git a/tests/unit/sagemaker/aws_batch/mock_client.py b/tests/unit/sagemaker/aws_batch/mock_client.py new file mode 100644 index 0000000000..c13bb9db93 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/mock_client.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from typing import Optional, List, Dict +from .constants import ( + JOB_ARN, + JOB_ID, + FIRST_LIST_SERVICE_JOB_RESP, + EMPTY_LIST_SERVICE_JOB_RESP, + JOB_STATUS_RUNNING, + TIMEOUT_CONFIG, +) + + +class MockClient: + def submit_service_job( + self, + jobName, + jobQueue, + serviceRequestPayload, + serviceJobType, + retryStrategy: Optional[Dict] = None, + schedulingPriority: Optional[int] = None, + shareIdentifier: Optional[str] = "", + tags: Optional[Dict] = None, + timeoutConfig: Optional[Dict] = TIMEOUT_CONFIG, + ): + return {"jobArn": JOB_ARN, "jobName": jobName, "jobId": JOB_ID} + + def describe_service_job(self, jobId): + return {"jobId": jobId} + + def terminate_service_job(self, jobId, reason): + return {} + + def list_service_jobs( + self, + jobQueue, + jobStatus: Optional[str] = JOB_STATUS_RUNNING, + nextToken: Optional[str] = "", + filters: Optional[List] = [], + ): + if nextToken: + return FIRST_LIST_SERVICE_JOB_RESP + else: + return EMPTY_LIST_SERVICE_JOB_RESP diff --git a/tests/unit/sagemaker/aws_batch/mock_estimator.py b/tests/unit/sagemaker/aws_batch/mock_estimator.py new file mode 100644 index 0000000000..aa3d9e1b20 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/mock_estimator.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from sagemaker.estimator import Estimator +from sagemaker.pytorch import PyTorch + + +class Estimator(Estimator): + def __init__(self): + self.sagemaker_session = Session() + self.tags = [ + {"Key": "batch-non-prod", "Value": "true"}, + {"Key": "batch-training-job-name", "Value": "training-job"}, + ] + + def prepare_workflow_for_training(self, job_name): + pass + + +class PyTorch(PyTorch): + def __init__(self): + self.sagemaker_session = Session() + self.tags = [ + {"Key": "batch-non-prod", "Value": "true"}, + {"Key": "batch-training-job-name", "Value": "training-job"}, + ] + + def prepare_workflow_for_training(self, job_name): + pass + + +class Session: + def __init__(self): + pass + + def get_train_request(self, **kwargs): + return kwargs diff --git a/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py b/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py new file mode 100644 index 0000000000..e9384c135c --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py @@ -0,0 +1,186 @@ +from __future__ import absolute_import +from sagemaker.aws_batch.batch_api_helper import ( + submit_service_job, + terminate_service_job, + describe_service_job, + list_service_job, + __merge_tags, +) + +import json +import pytest +from mock.mock import patch + +from sagemaker.aws_batch.constants import ( + DEFAULT_TIMEOUT, + 
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SAGEMAKER_TRAINING, +) +from .mock_client import MockClient +from .constants import ( + JOB_NAME, + JOB_QUEUE, + SCHEDULING_PRIORITY, + JOB_ID, + REASON, + SHARE_IDENTIFIER, + BATCH_TAGS, + TRAINING_TAGS, + TRAINING_TAGS_DUPLICATING_BATCH_TAGS, + TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS, + MERGED_TAGS, + MERGED_TAGS_TRAINING_OVERRIDE, + JOB_STATUS_RUNNING, + NEXT_TOKEN, +) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_submit_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + training_payload = {} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +@patch("sagemaker.aws_batch.batch_api_helper.__merge_tags") +@pytest.mark.parametrize( + "batch_tags,training_tags", + [ + (BATCH_TAGS, TRAINING_TAGS), + (None, TRAINING_TAGS), + ({}, TRAINING_TAGS), + (BATCH_TAGS, None), + (BATCH_TAGS, []), + ], +) +def test_submit_service_job_called_with_merged_tags( + patched_merge_tags, patched_get_batch_boto_client, batch_tags, training_tags +): + mock_client = MockClient() + patched_get_batch_boto_client.return_value = mock_client + patched_merge_tags.return_value = MERGED_TAGS + + with patch.object( + mock_client, "submit_service_job", wraps=mock_client.submit_service_job + ) as wrapped_submit_service_job: + training_payload = {"Tags": training_tags} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + batch_tags, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + patched_merge_tags.assert_called_once_with(batch_tags, training_tags) + wrapped_submit_service_job.assert_called_once_with( + jobName=JOB_NAME, + jobQueue=JOB_QUEUE, + retryStrategy=DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + serviceJobType=SAGEMAKER_TRAINING, + serviceRequestPayload=json.dumps(training_payload), + timeoutConfig=DEFAULT_TIMEOUT, + schedulingPriority=SCHEDULING_PRIORITY, + shareIdentifier=SHARE_IDENTIFIER, + tags={**MERGED_TAGS}, + ) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +@patch("sagemaker.aws_batch.batch_api_helper.__merge_tags") +def test_submit_service_job_not_called_with_tags(patched_merge_tags, patched_get_batch_boto_client): + mock_client = MockClient() + patched_get_batch_boto_client.return_value = mock_client + patched_merge_tags.return_value = MERGED_TAGS + + with patch.object( + mock_client, "submit_service_job", wraps=mock_client.submit_service_job + ) as wrapped_submit_service_job: + training_payload = {} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + patched_merge_tags.assert_not_called() + wrapped_submit_service_job.assert_called_once_with( + jobName=JOB_NAME, + jobQueue=JOB_QUEUE, + retryStrategy=DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + serviceJobType=SAGEMAKER_TRAINING, + serviceRequestPayload=json.dumps(training_payload), + timeoutConfig=DEFAULT_TIMEOUT, + 
schedulingPriority=SCHEDULING_PRIORITY, + shareIdentifier=SHARE_IDENTIFIER, + ) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_describe_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + resp = describe_service_job(job_id=JOB_ID) + assert resp["jobId"] == JOB_ID + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_terminate_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + resp = terminate_service_job(job_id=JOB_ID, reason=REASON) + assert len(resp) == 0 + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_list_service_job_has_next_token(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + gen = list_service_job(job_queue=None, job_status=JOB_STATUS_RUNNING, next_token=NEXT_TOKEN) + resp = next(gen) + assert resp["nextToken"] == NEXT_TOKEN + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_list_service_job_no_next_token(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + gen = list_service_job(job_queue=None, job_status=JOB_STATUS_RUNNING, next_token=None) + resp = next(gen) + assert resp["nextToken"] is None + + +@pytest.mark.parametrize( + "batch_tags,training_tags,expected", + [ + (BATCH_TAGS, TRAINING_TAGS, MERGED_TAGS), + (BATCH_TAGS, TRAINING_TAGS_DUPLICATING_BATCH_TAGS, MERGED_TAGS_TRAINING_OVERRIDE), + (BATCH_TAGS, None, BATCH_TAGS), + (BATCH_TAGS, [], BATCH_TAGS), + (None, TRAINING_TAGS, TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS), + ({}, TRAINING_TAGS, TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS), + ], +) +def test___merge_tags(batch_tags, training_tags, expected): + result = __merge_tags(batch_tags=batch_tags, training_tags=training_tags) + assert result == expected diff --git a/tests/unit/sagemaker/aws_batch/test_training_queue.py b/tests/unit/sagemaker/aws_batch/test_training_queue.py new file mode 100644 index 0000000000..6fee3efad7 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_training_queue.py @@ -0,0 +1,411 @@ +from __future__ import absolute_import +from sagemaker.aws_batch.constants import DEFAULT_TIMEOUT +from sagemaker.aws_batch.exception import MissingRequiredArgument +from sagemaker.aws_batch.training_queue import TrainingQueue + +from unittest.mock import Mock, call +from mock.mock import patch +import pytest + +from sagemaker.modules.train.model_trainer import ModelTrainer, Mode +from sagemaker.estimator import _TrainingJob +from .constants import ( + JOB_QUEUE, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + JOB_ARN, + SUBMIT_SERVICE_JOB_RESP, + JOB_NAME_IN_PAYLOAD, + JOB_STATUS_RUNNING, + EMPTY_LIST_SERVICE_JOB_RESP, + FIRST_LIST_SERVICE_JOB_RESP, + INCORRECT_FIRST_LIST_SERVICE_JOB_RESP, + EXPERIMENT_CONFIG_EMPTY, + SECOND_LIST_SERVICE_JOB_RESP, + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, +) +from .mock_estimator import Estimator, PyTorch + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_with_timeout(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + Estimator(), + {}, + JOB_NAME, + 
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_use_default_timeout(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + Estimator(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + None, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_with_job_name(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + Estimator(), + {}, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME_IN_PAYLOAD, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_encounter_error(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = {} + + queue = TrainingQueue(JOB_QUEUE) + with pytest.raises(MissingRequiredArgument): + queue.submit( + Estimator(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + + +def test_queue_map_with_job_names_mismatch_input_length_encounter_error(): + queue = TrainingQueue(JOB_QUEUE) + with pytest.raises(ValueError): + queue.map(Estimator(), {}, [JOB_NAME]) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_map_happy_case(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + input_list = {"test-input", "test-input-2"} + + queue = TrainingQueue(JOB_QUEUE) + queue.map( + Estimator(), + input_list, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + assert patched_submit_service_job.call_count == len(input_list) + + 
+@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_map_with_job_names(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + input_list = {"test-input", "test-input-2"} + job_names = [JOB_NAME, "job-name-2"] + + queue = TrainingQueue(JOB_QUEUE) + queue.map( + Estimator(), + input_list, + job_names, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + assert patched_submit_service_job.call_count == len(input_list) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_default_argument(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + patched_list_service_job.return_value = [{"jobSummaryList": [], "nextToken": None}] + queue.list_jobs() + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, JOB_STATUS_RUNNING, None, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_with_job_name(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = [{"jobSummaryList": [], "nextToken": None}] + + queue.list_jobs(JOB_NAME, None) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, None, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_with_job_status(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = None + + patched_list_service_job.return_value = [EMPTY_LIST_SERVICE_JOB_RESP] + + queue.list_jobs(None, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, JOB_STATUS_RUNNING, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_has_next_token(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + first_output = FIRST_LIST_SERVICE_JOB_RESP + second_output = SECOND_LIST_SERVICE_JOB_RESP + third_output = EMPTY_LIST_SERVICE_JOB_RESP + patched_list_service_job.return_value = iter([first_output, second_output, third_output]) + + jobs = queue.list_jobs(JOB_NAME, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert len(jobs) == 3 + assert jobs[0].job_arn == JOB_ARN + assert jobs[0].job_name == JOB_NAME + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_without_job_arn_in_list_resp(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + first_output = INCORRECT_FIRST_LIST_SERVICE_JOB_RESP + second_output = EMPTY_LIST_SERVICE_JOB_RESP + patched_list_service_job.return_value = iter([first_output, second_output]) + + jobs = queue.list_jobs(JOB_NAME, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert len(jobs) == 0 + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_get_happy_case_job_exists(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = 
[FIRST_LIST_SERVICE_JOB_RESP] + + job = queue.get_job(JOB_NAME) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert job.job_name == JOB_NAME + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_get_job_not_found_encounter_error(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = [EMPTY_LIST_SERVICE_JOB_RESP] + + with pytest.raises(ValueError): + queue.get_job(JOB_NAME) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, None, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_submit_model_trainer(patch_submit_service_job): + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + payload = { + "TrainingJobName": JOB_NAME, + "ResourceConfig": { + "InstanceType": "ml.m5.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, + }, + } + trainer._create_training_job_args.return_value = payload + + patch_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patch_submit_service_job.assert_called_once_with( + payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +def test_submit_model_trainer_fail(): + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.LOCAL_CONTAINER + + with pytest.raises( + ValueError, + match="TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB", + ): + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_submit_pytorch_estimator(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + PyTorch(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + DEFAULT_TIMEOUT, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +def test_submit_with_invalid_training_job(): + with pytest.raises( + TypeError, + match="training_job must be an instance of EstimatorBase or ModelTrainer", + ): + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + TrainingQueue("NotAnEstimatorOrModelTrainer"), + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) diff --git 
a/tests/unit/sagemaker/aws_batch/test_training_queued_job.py b/tests/unit/sagemaker/aws_batch/test_training_queued_job.py new file mode 100644 index 0000000000..fe5231a01d --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_training_queued_job.py @@ -0,0 +1,170 @@ +from __future__ import absolute_import + +import pytest +import time +from mock.mock import patch +from unittest.mock import Mock + +from sagemaker.aws_batch.exception import NoTrainingJob, MissingRequiredArgument +from sagemaker.aws_batch.training_queued_job import TrainingQueuedJob +from sagemaker.config import SAGEMAKER, TRAINING_JOB +from .constants import ( + JOB_ARN, + JOB_NAME, + REASON, + TRAINING_IMAGE, + JOB_STATUS_RUNNING, + JOB_STATUS_RUNNABLE, + JOB_STATUS_FAILED, + JOB_STATUS_COMPLETED, + EXECUTION_ROLE, + TRAINING_JOB_ARN, +) +from tests.unit import SAGEMAKER_CONFIG_TRAINING_JOB + + +@patch("sagemaker.aws_batch.training_queued_job.terminate_service_job") +def test_queued_job_terminate(patched_terminate_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate(REASON) + patched_terminate_service_job.assert_called_once_with(queued_job.job_arn, REASON) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_describe(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.describe() + patched_describe_service_job.assert_called_once_with(queued_job.job_arn) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_estimator_no_training_job_created(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNABLE} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + with pytest.raises(NoTrainingJob): + queued_job.get_estimator() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_estimator_missing_required_argument(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + with pytest.raises(MissingRequiredArgument): + queued_job.get_estimator() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@patch("sagemaker.aws_batch.training_queued_job._construct_estimator_from_training_job_name") +def test_queued_job_estimator_happy_case( + patched_construct_estimator_from_training_job_name, patched_describe_service_job +): + training_job_config = SAGEMAKER_CONFIG_TRAINING_JOB[SAGEMAKER][TRAINING_JOB] + training_job_config["image_uri"] = TRAINING_IMAGE + training_job_config["job_name"] = JOB_NAME + training_job_config["role"] = EXECUTION_ROLE + describe_resp = { + "status": JOB_STATUS_RUNNING, + "latestAttempt": { + "serviceResourceId": {"name": "trainingJobArn", "value": TRAINING_JOB_ARN} + }, + } + patched_describe_service_job.return_value = describe_resp + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.get_estimator() + patched_construct_estimator_from_training_job_name.assert_called_once_with(JOB_NAME) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_wait_no_timeout(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + assert result.get("status", "") == JOB_STATUS_COMPLETED + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def 
test_queued_job_wait_with_timeout_succeeds(patched_describe_service_job): + patched_describe_service_job.side_effect = [ + {"status": JOB_STATUS_RUNNING}, + {"status": JOB_STATUS_RUNNING}, + {"status": JOB_STATUS_COMPLETED}, + ] + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=15) + end_time = time.time() + + assert end_time - start_time < 15 + assert result.get("status", "") == JOB_STATUS_COMPLETED + assert patched_describe_service_job.call_count == 3 + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_wait_with_timeout_times_out(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=5) + end_time = time.time() + + assert end_time - start_time > 5 + assert result.get("status", "") == JOB_STATUS_RUNNING + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_happy_case(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + # queued_job.describe.return_value = {"status": JOB_STATUS_COMPLETED} + patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED} + + result = await queued_job.fetch_job_results() + assert result == {"status": JOB_STATUS_COMPLETED} + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_job_failed(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + patched_describe_service_job.return_value = { + "status": JOB_STATUS_FAILED, + "statusReason": "Job failed", + } + + with pytest.raises(RuntimeError): + await queued_job.fetch_job_results() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_timeout(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + + with pytest.raises(TimeoutError): + await queued_job.fetch_job_results(timeout=1) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queue_result_happy_case(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED} + + result = queued_job.result(100) + assert result == {"status": JOB_STATUS_COMPLETED} + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queue_result_job_times_out(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + + with pytest.raises(TimeoutError): + queued_job.result(1) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 184f9c30da..73893ea7f4 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1302,6 +1302,53 @@ def mock_upload_data(path, bucket, key_prefix): assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard" 
+def test_create_training_job_args(modules_session): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + + args = model_trainer._create_training_job_args() + assert args["algorithm_specification"] == AlgorithmSpecification( + training_image=DEFAULT_IMAGE, + algorithm_name=None, + training_input_mode="File", + container_entrypoint=None, + container_arguments=None, + training_image_config=None, + metric_definitions=None, + ) + assert args["resource_config"] == ResourceConfig( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + ) + assert args["role_arn"] == DEFAULT_ROLE + + +def test_create_training_job_args_boto3(modules_session): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + + args = model_trainer._create_training_job_args(boto3=True) + assert args["AlgorithmSpecification"] == { + "TrainingImage": DEFAULT_IMAGE, + "TrainingInputMode": "File", + } + assert args["ResourceConfig"] == { + "InstanceType": DEFAULT_INSTANCE_TYPE, + "InstanceCount": 1, + "VolumeSizeInGB": 30, + } + assert args["RoleArn"] == DEFAULT_ROLE + + @patch("sagemaker.modules.train.model_trainer.TrainingJob") def test_input_merge(mock_training_job, modules_session): model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz") diff --git a/tox.ini b/tox.ini index e4df36587a..9c624b2052 100644 --- a/tox.ini +++ b/tox.ini @@ -68,6 +68,8 @@ markers = setenv = PYTHONHASHSEED=42 pip_version = pip==24.3 +allowlist_externals = + aws passenv = AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY @@ -82,6 +84,7 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" + pip install 'apache-airflow==2.10.4' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.4/constraints-3.9.txt" pip install 'torch==2.3.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' @@ -90,7 +93,11 @@ commands = pip install -U "sagemaker-core" # needed to keep sagemaker-core up to date pytest {posargs} -deps = .[test] +deps = + .[test] + asyncio + nest_asyncio + pytest-asyncio depends = {py39,py310,py311,py312}: clean
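A minimal usage sketch of the TrainingQueue workflow exercised by the integration test above, assuming this branch of the SageMaker Python SDK is installed. The queue name, training image URI, source directory, and S3 input path are illustrative placeholders, and the execution role and boto session are assumed to resolve from the environment.

# Sketch: submit a ModelTrainer through an AWS Batch TrainingQueue, assuming the
# resources below exist; placeholders must be replaced with real values.
from sagemaker.aws_batch.training_queue import TrainingQueue
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import SourceCode, Compute, InputData

trainer = ModelTrainer(
    training_image="<training-image-uri>",                       # placeholder image
    source_code=SourceCode(source_dir="./src", entry_script="train.py"),
    compute=Compute(instance_type="ml.m5.xlarge"),
    base_job_name="batch-queued-training",
)

queue = TrainingQueue(queue_name="<batch-job-queue>")             # placeholder queue

# submit() enqueues the training job on AWS Batch rather than starting it directly.
queued_job = queue.submit(
    training_job=trainer,
    inputs=[InputData(channel_name="train", data_source="s3://<bucket>/train/")],
)

print(queued_job.describe()["status"])    # e.g. SUBMITTED while queued
queued_job.wait(timeout=1800)             # poll until the job completes or the timeout elapses
print(queued_job.describe()["status"])    # SUCCEEDED on a successful run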