|
13 | 13 | """Placeholder docstring"""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
| 16 | +import enum |
16 | 17 | import datetime
|
17 | 18 | import json
|
18 | 19 | import logging
|
19 | 20 | import os
|
20 | 21 | import tempfile
|
21 | 22 | import time
|
| 23 | +from uuid import uuid4 |
| 24 | +from copy import deepcopy |
| 25 | +from botocore.exceptions import ClientError |
22 | 26 |
|
23 | 27 | import sagemaker.local.data
|
| 28 | + |
24 | 29 | from sagemaker.local.image import _SageMakerContainer
|
25 | 30 | from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
|
26 | 31 | from sagemaker.utils import DeferredError, get_config_value
|
| 32 | +from sagemaker.local.exceptions import StepExecutionException |
27 | 33 |
|
28 | 34 | logger = logging.getLogger(__name__)
|
29 | 35 |
|
@@ -618,6 +624,205 @@ def describe(self):
|
618 | 624 | return response
|
619 | 625 |
|
620 | 626 |
|
| 627 | +class _LocalPipeline(object): |
| 628 | + """Placeholder docstring""" |
| 629 | + |
| 630 | + _executions = {} |
| 631 | + |
| 632 | + def __init__( |
| 633 | + self, |
| 634 | + pipeline, |
| 635 | + pipeline_description=None, |
| 636 | + local_session=None, |
| 637 | + ): |
| 638 | + from sagemaker.local import LocalSession |
| 639 | + |
| 640 | + self.local_session = local_session or LocalSession() |
| 641 | + self.pipeline = pipeline |
| 642 | + self.pipeline_description = pipeline_description |
| 643 | + now_time = datetime.datetime.now() |
| 644 | + self.creation_time = now_time |
| 645 | + self.last_modified_time = now_time |
| 646 | + |
| 647 | + def describe(self): |
| 648 | + """Placeholder docstring""" |
| 649 | + response = { |
| 650 | + "PipelineArn": self.pipeline.name, |
| 651 | + "PipelineDefinition": self.pipeline.definition(), |
| 652 | + "PipelineDescription": self.pipeline_description, |
| 653 | + "PipelineName": self.pipeline.name, |
| 654 | + "PipelineStatus": "Active", |
| 655 | + "RoleArn": "<no_role>", |
| 656 | + "CreationTime": self.creation_time, |
| 657 | + "LastModifiedTime": self.last_modified_time, |
| 658 | + } |
| 659 | + return response |
| 660 | + |
| 661 | + def start(self, **kwargs): |
| 662 | + """Placeholder docstring""" |
| 663 | + from sagemaker.local.pipeline import LocalPipelineExecutor |
| 664 | + |
| 665 | + execution_id = str(uuid4()) |
| 666 | + execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs) |
| 667 | + |
| 668 | + self._executions[execution_id] = execution |
| 669 | + return LocalPipelineExecutor(execution, self.local_session).execute() |
| 670 | + |
| 671 | + |
| 672 | +class _LocalPipelineExecution(object): |
| 673 | + """Placeholder docstring""" |
| 674 | + |
| 675 | + def __init__( |
| 676 | + self, |
| 677 | + execution_id, |
| 678 | + pipeline, |
| 679 | + PipelineParameters=None, |
| 680 | + PipelineExecutionDescription=None, |
| 681 | + PipelineExecutionDisplayName=None, |
| 682 | + ): |
| 683 | + self.pipeline = pipeline |
| 684 | + self.pipeline_execution_name = execution_id |
| 685 | + self.pipeline_execution_description = PipelineExecutionDescription |
| 686 | + self.pipeline_execution_display_name = PipelineExecutionDisplayName |
| 687 | + self.status = _LocalExecutionStatus.EXECUTING.value |
| 688 | + self.failure_reason = None |
| 689 | + self.creation_time = datetime.datetime.now() |
| 690 | + self.step_execution = self._initialize_step_execution() |
| 691 | + self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters) |
| 692 | + |
| 693 | + def describe(self): |
| 694 | + """Placeholder docstring""" |
| 695 | + response = { |
| 696 | + "CreationTime": self.creation_time, |
| 697 | + "LastModifiedTime": self.creation_time, |
| 698 | + "FailureReason": self.failure_reason, |
| 699 | + "PipelineArn": self.pipeline.name, |
| 700 | + "PipelineExecutionArn": self.pipeline_execution_name, |
| 701 | + "PipelineExecutionDescription": self.pipeline_execution_description, |
| 702 | + "PipelineExecutionDisplayName": self.pipeline_execution_display_name, |
| 703 | + "PipelineExecutionStatus": self.status, |
| 704 | + } |
| 705 | + filtered_response = {k: v for k, v in response.items() if v is not None} |
| 706 | + return filtered_response |
| 707 | + |
| 708 | + def list_steps(self): |
| 709 | + """Placeholder docstring""" |
| 710 | + # TODO |
| 711 | + |
| 712 | + def update_execution_failure(self, step_name, failure_message): |
| 713 | + """Mark execution as failed.""" |
| 714 | + self.status = _LocalExecutionStatus.FAILED.value |
| 715 | + self.failure_reason = f"Step {step_name} failed with message: {failure_message}" |
| 716 | + logger.error("Pipeline execution failed because step %s failed.", step_name) |
| 717 | + |
| 718 | + def update_step_failure(self, step_name, failure_message): |
| 719 | + """Mark step_name as failed.""" |
| 720 | + self.step_execution.get(step_name).update_step_failure(failure_message) |
| 721 | + |
| 722 | + def mark_step_starting(self, step_name): |
| 723 | + """Update step's status to EXECUTING""" |
| 724 | + self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING |
| 725 | + |
| 726 | + def _initialize_step_execution(self): |
| 727 | + """Initialize step_execution dict.""" |
| 728 | + from sagemaker.workflow.steps import StepTypeEnum |
| 729 | + |
| 730 | + supported_steps_types = ( |
| 731 | + StepTypeEnum.TRAINING, |
| 732 | + StepTypeEnum.PROCESSING, |
| 733 | + StepTypeEnum.TRANSFORM, |
| 734 | + StepTypeEnum.CONDITION, |
| 735 | + StepTypeEnum.FAIL, |
| 736 | + ) |
| 737 | + |
| 738 | + step_execution = {} |
| 739 | + for step in self.pipeline.steps: |
| 740 | + if step.step_type not in supported_steps_types: |
| 741 | + error_msg = self._construct_validation_exception_message( |
| 742 | + "Step type {} is not supported in local mode.".format(step.step_type.value) |
| 743 | + ) |
| 744 | + raise ClientError(error_msg, "start_pipeline_execution") |
| 745 | + step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type) |
| 746 | + return step_execution |
| 747 | + |
| 748 | + def _initialize_and_validate_parameters(self, overridden_parameters): |
| 749 | + """Initialize and validate pipeline parameters.""" |
| 750 | + merged_parameters = {} |
| 751 | + default_parameters = {parameter.name: parameter for parameter in self.pipeline.parameters} |
| 752 | + if overridden_parameters is not None: |
| 753 | + for (param_name, param_value) in overridden_parameters.items(): |
| 754 | + if param_name not in default_parameters: |
| 755 | + error_msg = self._construct_validation_exception_message( |
| 756 | + "Unknown parameter '{}'".format(param_name) |
| 757 | + ) |
| 758 | + raise ClientError(error_msg, "start_pipeline_execution") |
| 759 | + parameter_type = default_parameters[param_name].parameter_type |
| 760 | + if type(param_value) != parameter_type.python_type: # pylint: disable=C0123 |
| 761 | + error_msg = self._construct_validation_exception_message( |
| 762 | + "Unexpected type for parameter '{}'. Expected {} but found " |
| 763 | + "{}.".format(param_name, parameter_type.python_type, type(param_value)) |
| 764 | + ) |
| 765 | + raise ClientError(error_msg, "start_pipeline_execution") |
| 766 | + merged_parameters[param_name] = param_value |
| 767 | + for param_name, default_parameter in default_parameters.items(): |
| 768 | + if param_name not in merged_parameters: |
| 769 | + if default_parameter.default_value is None: |
| 770 | + error_msg = self._construct_validation_exception_message( |
| 771 | + "Parameter '{}' is undefined.".format(param_name) |
| 772 | + ) |
| 773 | + raise ClientError(error_msg, "start_pipeline_execution") |
| 774 | + merged_parameters[param_name] = default_parameter.default_value |
| 775 | + return merged_parameters |
| 776 | + |
| 777 | + @staticmethod |
| 778 | + def _construct_validation_exception_message(exception_msg): |
| 779 | + """Construct error response for botocore.exceptions.ClientError""" |
| 780 | + return {"Error": {"Code": "ValidationException", "Message": exception_msg}} |
| 781 | + |
| 782 | + |
| 783 | +class _LocalPipelineStepExecution(object): |
| 784 | + """Placeholder docstring""" |
| 785 | + |
| 786 | + def __init__( |
| 787 | + self, |
| 788 | + step_name, |
| 789 | + step_type, |
| 790 | + last_modified_time=None, |
| 791 | + status=None, |
| 792 | + properties=None, |
| 793 | + failure_reason=None, |
| 794 | + ): |
| 795 | + self.step_name = step_name |
| 796 | + self.step_type = step_type |
| 797 | + self.status = status or _LocalExecutionStatus.STARTING |
| 798 | + self.failure_reason = failure_reason |
| 799 | + self.properties = properties or {} |
| 800 | + self.creation_time = datetime.datetime.now() |
| 801 | + self.last_modified_time = last_modified_time or self.creation_time |
| 802 | + |
| 803 | + def update_step_properties(self, properties): |
| 804 | + """Update pipeline step execution output properties.""" |
| 805 | + logger.info("Successfully completed step %s.", self.step_name) |
| 806 | + self.properties = deepcopy(properties) |
| 807 | + self.status = _LocalExecutionStatus.SUCCEEDED.value |
| 808 | + |
| 809 | + def update_step_failure(self, failure_message): |
| 810 | + """Update pipeline step execution failure status and message.""" |
| 811 | + logger.error(failure_message) |
| 812 | + self.failure_reason = failure_message |
| 813 | + self.status = _LocalExecutionStatus.FAILED.value |
| 814 | + raise StepExecutionException(self.step_name, failure_message) |
| 815 | + |
| 816 | + |
| 817 | +class _LocalExecutionStatus(enum.Enum): |
| 818 | + """Placeholder docstring""" |
| 819 | + |
| 820 | + STARTING = "Starting" |
| 821 | + EXECUTING = "Executing" |
| 822 | + SUCCEEDED = "Succeeded" |
| 823 | + FAILED = "Failed" |
| 824 | + |
| 825 | + |
621 | 826 | def _wait_for_serving_container(serving_port):
|
622 | 827 | """Placeholder docstring."""
|
623 | 828 | i = 0
|
|
0 commit comments