Skip to content

Commit 87b4058

Browse files
nmadanNamrata Madan
authored andcommitted
feature: Pipelines local mode setup
Co-authored-by: Namrata Madan <[email protected]>
1 parent be95f4e commit 87b4058

18 files changed

+805
-9
lines changed

src/sagemaker/local/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
LocalSagemakerClient,
1919
LocalSagemakerRuntimeClient,
2020
LocalSession,
21+
LocalPipelineSession,
2122
)

src/sagemaker/local/entities.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,23 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import enum
1617
import datetime
1718
import json
1819
import logging
1920
import os
2021
import tempfile
2122
import time
23+
from uuid import uuid4
24+
from copy import deepcopy
25+
from botocore.exceptions import ClientError
2226

2327
import sagemaker.local.data
28+
2429
from sagemaker.local.image import _SageMakerContainer
2530
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
2631
from sagemaker.utils import DeferredError, get_config_value
32+
from sagemaker.local.exceptions import StepExecutionException
2733

2834
logger = logging.getLogger(__name__)
2935

@@ -618,6 +624,205 @@ def describe(self):
618624
return response
619625

620626

627+
class _LocalPipeline(object):
628+
"""Placeholder docstring"""
629+
630+
_executions = {}
631+
632+
def __init__(
633+
self,
634+
pipeline,
635+
pipeline_description=None,
636+
local_session=None,
637+
):
638+
from sagemaker.local import LocalSession
639+
640+
self.local_session = local_session or LocalSession()
641+
self.pipeline = pipeline
642+
self.pipeline_description = pipeline_description
643+
now_time = datetime.datetime.now()
644+
self.creation_time = now_time
645+
self.last_modified_time = now_time
646+
647+
def describe(self):
648+
"""Placeholder docstring"""
649+
response = {
650+
"PipelineArn": self.pipeline.name,
651+
"PipelineDefinition": self.pipeline.definition(),
652+
"PipelineDescription": self.pipeline_description,
653+
"PipelineName": self.pipeline.name,
654+
"PipelineStatus": "Active",
655+
"RoleArn": "<no_role>",
656+
"CreationTime": self.creation_time,
657+
"LastModifiedTime": self.last_modified_time,
658+
}
659+
return response
660+
661+
def start(self, **kwargs):
662+
"""Placeholder docstring"""
663+
from sagemaker.local.pipeline import LocalPipelineExecutor
664+
665+
execution_id = str(uuid4())
666+
execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs)
667+
668+
self._executions[execution_id] = execution
669+
return LocalPipelineExecutor(execution, self.local_session).execute()
670+
671+
672+
class _LocalPipelineExecution(object):
673+
"""Placeholder docstring"""
674+
675+
def __init__(
676+
self,
677+
execution_id,
678+
pipeline,
679+
PipelineParameters=None,
680+
PipelineExecutionDescription=None,
681+
PipelineExecutionDisplayName=None,
682+
):
683+
self.pipeline = pipeline
684+
self.pipeline_execution_name = execution_id
685+
self.pipeline_execution_description = PipelineExecutionDescription
686+
self.pipeline_execution_display_name = PipelineExecutionDisplayName
687+
self.status = _LocalExecutionStatus.EXECUTING.value
688+
self.failure_reason = None
689+
self.creation_time = datetime.datetime.now()
690+
self.step_execution = self._initialize_step_execution()
691+
self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters)
692+
693+
def describe(self):
694+
"""Placeholder docstring"""
695+
response = {
696+
"CreationTime": self.creation_time,
697+
"LastModifiedTime": self.creation_time,
698+
"FailureReason": self.failure_reason,
699+
"PipelineArn": self.pipeline.name,
700+
"PipelineExecutionArn": self.pipeline_execution_name,
701+
"PipelineExecutionDescription": self.pipeline_execution_description,
702+
"PipelineExecutionDisplayName": self.pipeline_execution_display_name,
703+
"PipelineExecutionStatus": self.status,
704+
}
705+
filtered_response = {k: v for k, v in response.items() if v is not None}
706+
return filtered_response
707+
708+
def list_steps(self):
709+
"""Placeholder docstring"""
710+
# TODO
711+
712+
def update_execution_failure(self, step_name, failure_message):
713+
"""Mark execution as failed."""
714+
self.status = _LocalExecutionStatus.FAILED.value
715+
self.failure_reason = f"Step {step_name} failed with message: {failure_message}"
716+
logger.error("Pipeline execution failed because step %s failed.", step_name)
717+
718+
def update_step_failure(self, step_name, failure_message):
719+
"""Mark step_name as failed."""
720+
self.step_execution.get(step_name).update_step_failure(failure_message)
721+
722+
def mark_step_starting(self, step_name):
723+
"""Update step's status to EXECUTING"""
724+
self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING
725+
726+
def _initialize_step_execution(self):
727+
"""Initialize step_execution dict."""
728+
from sagemaker.workflow.steps import StepTypeEnum
729+
730+
supported_steps_types = (
731+
StepTypeEnum.TRAINING,
732+
StepTypeEnum.PROCESSING,
733+
StepTypeEnum.TRANSFORM,
734+
StepTypeEnum.CONDITION,
735+
StepTypeEnum.FAIL,
736+
)
737+
738+
step_execution = {}
739+
for step in self.pipeline.steps:
740+
if step.step_type not in supported_steps_types:
741+
error_msg = self._construct_validation_exception_message(
742+
"Step type {} is not supported in local mode.".format(step.step_type.value)
743+
)
744+
raise ClientError(error_msg, "start_pipeline_execution")
745+
step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type)
746+
return step_execution
747+
748+
def _initialize_and_validate_parameters(self, overridden_parameters):
749+
"""Initialize and validate pipeline parameters."""
750+
merged_parameters = {}
751+
default_parameters = {parameter.name: parameter for parameter in self.pipeline.parameters}
752+
if overridden_parameters is not None:
753+
for (param_name, param_value) in overridden_parameters.items():
754+
if param_name not in default_parameters:
755+
error_msg = self._construct_validation_exception_message(
756+
"Unknown parameter '{}'".format(param_name)
757+
)
758+
raise ClientError(error_msg, "start_pipeline_execution")
759+
parameter_type = default_parameters[param_name].parameter_type
760+
if type(param_value) != parameter_type.python_type: # pylint: disable=C0123
761+
error_msg = self._construct_validation_exception_message(
762+
"Unexpected type for parameter '{}'. Expected {} but found "
763+
"{}.".format(param_name, parameter_type.python_type, type(param_value))
764+
)
765+
raise ClientError(error_msg, "start_pipeline_execution")
766+
merged_parameters[param_name] = param_value
767+
for param_name, default_parameter in default_parameters.items():
768+
if param_name not in merged_parameters:
769+
if default_parameter.default_value is None:
770+
error_msg = self._construct_validation_exception_message(
771+
"Parameter '{}' is undefined.".format(param_name)
772+
)
773+
raise ClientError(error_msg, "start_pipeline_execution")
774+
merged_parameters[param_name] = default_parameter.default_value
775+
return merged_parameters
776+
777+
@staticmethod
778+
def _construct_validation_exception_message(exception_msg):
779+
"""Construct error response for botocore.exceptions.ClientError"""
780+
return {"Error": {"Code": "ValidationException", "Message": exception_msg}}
781+
782+
783+
class _LocalPipelineStepExecution(object):
784+
"""Placeholder docstring"""
785+
786+
def __init__(
787+
self,
788+
step_name,
789+
step_type,
790+
last_modified_time=None,
791+
status=None,
792+
properties=None,
793+
failure_reason=None,
794+
):
795+
self.step_name = step_name
796+
self.step_type = step_type
797+
self.status = status or _LocalExecutionStatus.STARTING
798+
self.failure_reason = failure_reason
799+
self.properties = properties or {}
800+
self.creation_time = datetime.datetime.now()
801+
self.last_modified_time = last_modified_time or self.creation_time
802+
803+
def update_step_properties(self, properties):
804+
"""Update pipeline step execution output properties."""
805+
logger.info("Successfully completed step %s.", self.step_name)
806+
self.properties = deepcopy(properties)
807+
self.status = _LocalExecutionStatus.SUCCEEDED.value
808+
809+
def update_step_failure(self, failure_message):
810+
"""Update pipeline step execution failure status and message."""
811+
logger.error(failure_message)
812+
self.failure_reason = failure_message
813+
self.status = _LocalExecutionStatus.FAILED.value
814+
raise StepExecutionException(self.step_name, failure_message)
815+
816+
817+
class _LocalExecutionStatus(enum.Enum):
818+
"""Placeholder docstring"""
819+
820+
STARTING = "Starting"
821+
EXECUTING = "Executing"
822+
SUCCEEDED = "Succeeded"
823+
FAILED = "Failed"
824+
825+
621826
def _wait_for_serving_container(serving_port):
622827
"""Placeholder docstring."""
623828
i = 0

src/sagemaker/local/exceptions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Custom Exceptions."""
14+
from __future__ import absolute_import
15+
16+
17+
class StepExecutionException(Exception):
18+
"""Exception indicating a failure while execution pipeline steps."""
19+
20+
def __init__(self, step_name, message):
21+
"""Placeholder docstring"""
22+
super(StepExecutionException, self).__init__(message)
23+
self.message = message
24+
self.step_name = step_name

0 commit comments

Comments
 (0)