Skip to content

Commit fa32de4

Browse files
nmadan (Namrata Madan)
authored and committed
change: implement local JsonGet function
Co-authored-by: Namrata Madan <[email protected]>
1 parent 243c3ae commit fa32de4

File tree

12 files changed

+851
-172
lines changed

12 files changed

+851
-172
lines changed

src/sagemaker/local/entities.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def describe(self):
625625

626626

627627
class _LocalPipeline(object):
628-
"""Placeholder docstring"""
628+
"""Class representing a local SageMaker Pipeline"""
629629

630630
_executions = {}
631631

@@ -645,7 +645,7 @@ def __init__(
645645
self.last_modified_time = now_time
646646

647647
def describe(self):
648-
"""Placeholder docstring"""
648+
"""Describe Pipeline"""
649649
response = {
650650
"PipelineArn": self.pipeline.name,
651651
"PipelineDefinition": self.pipeline.definition(),
@@ -659,7 +659,7 @@ def describe(self):
659659
return response
660660

661661
def start(self, **kwargs):
662-
"""Placeholder docstring"""
662+
"""Start a pipeline execution. Returns a _LocalPipelineExecution object."""
663663
from sagemaker.local.pipeline import LocalPipelineExecutor
664664

665665
execution_id = str(uuid4())
@@ -670,7 +670,7 @@ def start(self, **kwargs):
670670

671671

672672
class _LocalPipelineExecution(object):
673-
"""Placeholder docstring"""
673+
"""Class representing a local SageMaker pipeline execution."""
674674

675675
def __init__(
676676
self,
@@ -693,7 +693,7 @@ def __init__(
693693
self.blockout_steps = {}
694694

695695
def describe(self):
696-
"""Placeholder docstring"""
696+
"""Describe Pipeline Execution."""
697697
response = {
698698
"CreationTime": self.creation_time,
699699
"LastModifiedTime": self.creation_time,
@@ -708,8 +708,14 @@ def describe(self):
708708
return filtered_response
709709

710710
def list_steps(self):
711-
"""Placeholder docstring"""
712-
# TODO
711+
"""List pipeline execution steps."""
712+
return {
713+
"PipelineExecutionSteps": [
714+
step.to_list_steps_response()
715+
for step in self.step_execution.values()
716+
if step.status is not None
717+
]
718+
}
713719

714720
def update_execution_success(self):
715721
"""Mark execution as succeeded."""
@@ -730,8 +736,8 @@ def update_step_failure(self, step_name, failure_message):
730736
self.step_execution.get(step_name).update_step_failure(failure_message)
731737

732738
def mark_step_executing(self, step_name):
733-
"""Update step's status to EXECUTING"""
734-
self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING.value
739+
"""Update pipelines step's status to EXECUTING and start_time to now."""
740+
self.step_execution.get(step_name).mark_step_executing()
735741

736742
def _initialize_step_execution(self, steps):
737743
"""Initialize step_execution dict."""
@@ -751,7 +757,9 @@ def _initialize_step_execution(self, steps):
751757
"Step type {} is not supported in local mode.".format(step.step_type.value)
752758
)
753759
raise ClientError(error_msg, "start_pipeline_execution")
754-
self.step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type)
760+
self.step_execution[step.name] = _LocalPipelineExecutionStep(
761+
step.name, step.step_type, step.description, step.display_name
762+
)
755763
if step.step_type == StepTypeEnum.CONDITION:
756764
self._initialize_step_execution(step.if_steps + step.else_steps)
757765

@@ -790,44 +798,105 @@ def _construct_validation_exception_message(exception_msg):
790798
return {"Error": {"Code": "ValidationException", "Message": exception_msg}}
791799

792800

793-
class _LocalPipelineStepExecution(object):
794-
"""Placeholder docstring"""
801+
class _LocalPipelineExecutionStep(object):
802+
"""Class representing a local pipeline execution step."""
795803

796804
def __init__(
797805
self,
798-
step_name,
806+
name,
799807
step_type,
800-
last_modified_time=None,
808+
description,
809+
display_name=None,
810+
start_time=None,
811+
end_time=None,
801812
status=None,
802813
properties=None,
803814
failure_reason=None,
804815
):
805-
self.step_name = step_name
806-
self.step_type = step_type
807-
self.status = status or _LocalExecutionStatus.STARTING.value
816+
from sagemaker.workflow.steps import StepTypeEnum
817+
818+
self.name = name
819+
self.type = step_type
820+
self.description = description
821+
self.display_name = display_name
822+
self.status = status
808823
self.failure_reason = failure_reason
809824
self.properties = properties or {}
810-
self.creation_time = datetime.datetime.now()
811-
self.last_modified_time = last_modified_time or self.creation_time
825+
self.start_time = start_time
826+
self.end_time = end_time
827+
self._step_type_to_output_format_map = {
828+
StepTypeEnum.TRAINING: self._construct_training_metadata,
829+
StepTypeEnum.PROCESSING: self._construct_processing_metadata,
830+
StepTypeEnum.TRANSFORM: self._construct_transform_metadata,
831+
StepTypeEnum.CONDITION: self._construct_condition_metadata,
832+
StepTypeEnum.FAIL: self._construct_fail_metadata,
833+
}
812834

813835
def update_step_properties(self, properties):
814836
"""Update pipeline step execution output properties."""
815-
logger.info("Successfully completed step %s.", self.step_name)
837+
logger.info("Successfully completed step %s.", self.name)
816838
self.properties = deepcopy(properties)
817839
self.status = _LocalExecutionStatus.SUCCEEDED.value
840+
self.end_time = datetime.datetime.now()
818841

819842
def update_step_failure(self, failure_message):
820843
"""Update pipeline step execution failure status and message."""
821844
logger.error(failure_message)
822845
self.failure_reason = failure_message
823846
self.status = _LocalExecutionStatus.FAILED.value
824-
raise StepExecutionException(self.step_name, failure_message)
847+
self.end_time = datetime.datetime.now()
848+
raise StepExecutionException(self.name, failure_message)
849+
850+
def mark_step_executing(self):
851+
"""Update pipelines step's status to EXECUTING and start_time to now"""
852+
self.status = _LocalExecutionStatus.EXECUTING.value
853+
self.start_time = datetime.datetime.now()
854+
855+
def to_list_steps_response(self):
856+
"""Convert to response dict for list_steps calls."""
857+
response = {
858+
"EndTime": self.end_time,
859+
"FailureReason": self.failure_reason,
860+
"Metadata": self._construct_metadata(),
861+
"StartTime": self.start_time,
862+
"StepDescription": self.description,
863+
"StepDisplayName": self.display_name,
864+
"StepName": self.name,
865+
"StepStatus": self.status,
866+
}
867+
filtered_response = {k: v for k, v in response.items() if v is not None}
868+
return filtered_response
869+
870+
def _construct_metadata(self):
871+
"""Constructs the metadata shape for the list_steps_response."""
872+
if self.properties:
873+
return self._step_type_to_output_format_map[self.type]()
874+
return None
875+
876+
def _construct_training_metadata(self):
877+
"""Construct training job metadata response."""
878+
return {"TrainingJob": {"Arn": self.properties.TrainingJobArn}}
879+
880+
def _construct_processing_metadata(self):
881+
"""Construct processing job metadata response."""
882+
return {"ProcessingJob": {"Arn": self.properties.ProcessingJobArn}}
883+
884+
def _construct_transform_metadata(self):
885+
"""Construct transform job metadata response."""
886+
return {"TransformJob": {"Arn": self.properties.TransformJobArn}}
887+
888+
def _construct_condition_metadata(self):
889+
"""Construct condition step metadata response."""
890+
return {"Condition": {"Outcome": self.properties.Outcome}}
891+
892+
def _construct_fail_metadata(self):
893+
"""Construct fail step metadata response."""
894+
return {"Fail": {"ErrorMessage": self.properties.ErrorMessage}}
825895

826896

827897
class _LocalExecutionStatus(enum.Enum):
828-
"""Placeholder docstring"""
898+
"""Pipeline execution status."""
829899

830-
STARTING = "Starting"
831900
EXECUTING = "Executing"
832901
SUCCEEDED = "Succeeded"
833902
FAILED = "Failed"

src/sagemaker/local/local_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ class LocalSession(Session):
595595
:class:`~sagemaker.session.Session`.
596596
"""
597597

598-
def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False):
598+
def __init__(
599+
self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False
600+
):
599601
"""Create a Local SageMaker Session.
600602
601603
Args:
@@ -614,7 +616,7 @@ def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=F
614616
# discourage external use:
615617
self._disable_local_code = disable_local_code
616618

617-
super(LocalSession, self).__init__(boto_session)
619+
super(LocalSession, self).__init__(boto_session=boto_session, default_bucket=default_bucket)
618620

619621
if platform.system() == "Windows":
620622
logger.warning("Windows Support for Local Mode is Experimental")
@@ -718,9 +720,12 @@ def __init__(self, fileUri, content_type=None):
718720
class LocalPipelineSession(LocalSession):
719721
"""Class representing a local session for SageMaker Pipelines executions."""
720722

721-
def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False):
723+
def __init__(
724+
self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False
725+
):
722726
super().__init__(
723727
boto_session=boto_session,
728+
default_bucket=default_bucket,
724729
s3_endpoint_url=s3_endpoint_url,
725730
disable_local_code=disable_local_code,
726731
)

src/sagemaker/local/pipeline.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
from abc import ABC, abstractmethod
1616

1717
import logging
18+
import json
1819
from copy import deepcopy
1920
from datetime import datetime
2021
from typing import Dict, List
21-
from sagemaker.workflow.conditions import ConditionTypeEnum
22+
from botocore.exceptions import ClientError
2223

24+
from sagemaker.workflow.conditions import ConditionTypeEnum
2325
from sagemaker.workflow.steps import StepTypeEnum, Step
2426
from sagemaker.workflow.entities import PipelineVariable
2527
from sagemaker.workflow.parameters import Parameter
26-
from sagemaker.workflow.functions import Join, JsonGet
28+
from sagemaker.workflow.functions import Join, JsonGet, PropertyFile
2729
from sagemaker.workflow.properties import Properties
2830
from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
2931
from sagemaker.workflow.pipeline import PipelineGraph
@@ -116,8 +118,7 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name):
116118
elif isinstance(pipeline_variable, ExecutionVariable):
117119
value = self._evaluate_execution_variable(pipeline_variable)
118120
elif isinstance(pipeline_variable, JsonGet):
119-
# TODO
120-
raise NotImplementedError
121+
value = self._evaluate_json_get_function(pipeline_variable, step_name)
121122
else:
122123
self.execution.update_step_failure(
123124
step_name, f"Unrecognized pipeline variable {pipeline_variable.expr}."
@@ -133,7 +134,7 @@ def _evaluate_property_reference(self, pipeline_variable, step_name):
133134
referenced_step_name = pipeline_variable.step_name
134135
step_properties = self.execution.step_execution.get(referenced_step_name).properties
135136
return get_using_dot_notation(step_properties, pipeline_variable.path)
136-
except (KeyError, IndexError):
137+
except (KeyError, IndexError, TypeError):
137138
self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.")
138139

139140
def _evaluate_execution_variable(self, pipeline_variable):
@@ -154,6 +155,56 @@ def _evaluate_execution_variable(self, pipeline_variable):
154155
return datetime.now()
155156
return None
156157

158+
def _evaluate_json_get_function(self, pipeline_variable, step_name):
    """Evaluate JsonGet function runtime value.

    Resolves the property file referenced by ``pipeline_variable``, reads it
    from the S3 output of the referenced step, parses it as JSON and returns
    the value found at ``pipeline_variable.json_path``.

    Args:
        pipeline_variable: The ``JsonGet`` variable to evaluate. Its
            ``property_file`` is either a property file name (str) or a
            ``PropertyFile`` object.
        step_name: Name of the step being executed; failures are recorded
            against this step.

    Returns:
        The value located at ``json_path`` inside the parsed property file.
        Every failure path goes through ``self.execution.update_step_failure``,
        which is expected to raise.
    """
    referenced_step_name = pipeline_variable.step_name
    property_file_reference = pipeline_variable.property_file

    property_file = None
    if isinstance(property_file_reference, str):
        # A string reference is looked up by name among the property files
        # declared on the referenced step.
        processing_step = self.pipeline_dag.step_map[referenced_step_name]
        declared_files = getattr(processing_step, "property_files", None) or []
        property_file = next(
            (file for file in declared_files if file.name == property_file_reference),
            None,
        )
    elif isinstance(property_file_reference, PropertyFile):
        property_file = property_file_reference

    if property_file is None:
        # Report an unresolved reference explicitly instead of failing later
        # with AttributeError on ``None.output_name``.
        reference_name = getattr(property_file_reference, "name", property_file_reference)
        self.execution.update_step_failure(
            step_name,
            f"Property file reference '{reference_name}' is undefined in step "
            f"'{referenced_step_name}'.",
        )
        return None  # only reached if update_step_failure did not raise

    processing_step_response = self.execution.step_execution.get(
        referenced_step_name
    ).properties
    if (
        "ProcessingOutputConfig" not in processing_step_response
        or "Outputs" not in processing_step_response["ProcessingOutputConfig"]
    ):
        self.execution.update_step_failure(
            step_name,
            f"Step '{pipeline_variable.step_name}' does not yet contain processing outputs.",
        )
        return None  # only reached if update_step_failure did not raise

    processing_output_s3_bucket = None
    for output in processing_step_response["ProcessingOutputConfig"]["Outputs"]:
        if output["OutputName"] == property_file.output_name:
            processing_output_s3_bucket = output["S3Output"]["S3Uri"]
            break

    if processing_output_s3_bucket is None:
        # Without a matching output there is no S3 location to read from.
        self.execution.update_step_failure(
            step_name,
            f"Processing output '{property_file.output_name}' of property file "
            f"'{property_file.name}' was not found in step '{referenced_step_name}'.",
        )
        return None  # only reached if update_step_failure did not raise

    try:
        file_content = self.sagemaker_session.read_s3_file(
            processing_output_s3_bucket, property_file.path
        )
        file_json = json.loads(file_content)
        return get_using_dot_notation(file_json, pipeline_variable.json_path)
    except ClientError as e:
        # botocore nests the error code and message under response["Error"].
        error = e.response.get("Error", {})
        self.execution.update_step_failure(
            step_name,
            f"Received an error while reading file '{property_file.path}' from S3: "
            f"{error.get('Code')}: {error.get('Message')}",
        )
    except json.JSONDecodeError:
        self.execution.update_step_failure(
            step_name,
            f"Contents of property file '{property_file.name}' are not in valid JSON format.",
        )
    except (KeyError, IndexError, TypeError):
        self.execution.update_step_failure(
            step_name, f"Invalid json path '{pipeline_variable.json_path}'"
        )
207+
157208

158209
class _StepExecutor(ABC):
159210
"""An abstract base class for step executors running steps locally"""

src/sagemaker/local/utils.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def get_using_dot_notation(dictionary, keys):
166166
Nested object within dictionary as defined by "keys"
167167
168168
Raises:
169-
KeyError or IndexError if the provided key does not exist in input dictionary
169+
KeyError/IndexError/TypeError if the provided key does not exist in input dictionary
170170
"""
171171
if keys is None:
172172
return dictionary
@@ -175,14 +175,23 @@ def get_using_dot_notation(dictionary, keys):
175175
rest = None
176176
if len(split_keys) > 1:
177177
rest = split_keys[1]
178-
list_accessor = re.search(r"(\w+)\[(\d+)]", key)
179-
if list_accessor:
180-
key = list_accessor.group(1)
181-
list_index = int(list_accessor.group(2))
182-
return get_using_dot_notation(dictionary[key][list_index], rest)
183-
dict_accessor = re.search(r"(\w+)\[['\"](\S+)['\"]]", key)
184-
if dict_accessor:
185-
key = dict_accessor.group(1)
186-
inner_key = dict_accessor.group(2)
187-
return get_using_dot_notation(dictionary[key][inner_key], rest)
188-
return get_using_dot_notation(dictionary[key], rest)
178+
bracket_accessors = re.findall(r"\[(.+?)]", key)
179+
if bracket_accessors:
180+
pre_bracket_key = key.split("[", 1)[0]
181+
inner_dict = dictionary[pre_bracket_key]
182+
else:
183+
inner_dict = dictionary[key]
184+
for bracket_accessor in bracket_accessors:
185+
if (
186+
bracket_accessor.startswith("'")
187+
and bracket_accessor.endswith("'")
188+
or bracket_accessor.startswith('"')
189+
and bracket_accessor.endswith('"')
190+
):
191+
# key accessor
192+
inner_key = bracket_accessor[1:-1]
193+
else:
194+
# list accessor
195+
inner_key = int(bracket_accessor)
196+
inner_dict = inner_dict[inner_key]
197+
return get_using_dot_notation(inner_dict, rest)

0 commit comments

Comments
 (0)