Skip to content

Commit 6e8939c

Browse files
nmadanNamrata Madan
authored andcommitted
change: add local mode integ tests
Co-authored-by: Namrata Madan <[email protected]>
1 parent fa32de4 commit 6e8939c

File tree

13 files changed

+452
-123
lines changed

13 files changed

+452
-123
lines changed

src/sagemaker/local/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,4 @@
1818
LocalSagemakerClient,
1919
LocalSagemakerRuntimeClient,
2020
LocalSession,
21-
LocalPipelineSession,
2221
)

src/sagemaker/local/entities.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def __init__(self, container):
200200
self.start_time = None
201201
self.end_time = None
202202
self.environment = None
203+
self.training_job_name = ""
203204

204205
def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
205206
"""Starts a local training job.
@@ -244,10 +245,13 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
244245
)
245246
self.end_time = datetime.datetime.now()
246247
self.state = self._COMPLETED
248+
self.training_job_name = job_name
247249

248250
def describe(self):
249251
"""Placeholder docstring"""
250252
response = {
253+
"TrainingJobName": self.training_job_name,
254+
"TrainingJobArn": _UNUSED_ARN,
251255
"ResourceConfig": {"InstanceCount": self.container.instance_count},
252256
"TrainingJobStatus": self.state,
253257
"TrainingStartTime": self.start_time,
@@ -640,9 +644,8 @@ def __init__(
640644
self.local_session = local_session or LocalSession()
641645
self.pipeline = pipeline
642646
self.pipeline_description = pipeline_description
643-
now_time = datetime.datetime.now()
644-
self.creation_time = now_time
645-
self.last_modified_time = now_time
647+
self.creation_time = datetime.datetime.now().timestamp()
648+
self.last_modified_time = self.creation_time
646649

647650
def describe(self):
648651
"""Describe Pipeline"""
@@ -666,6 +669,13 @@ def start(self, **kwargs):
666669
execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs)
667670

668671
self._executions[execution_id] = execution
672+
logger.info(
673+
"Starting execution for pipeline %s. Execution ID is %s",
674+
self.pipeline.name,
675+
execution_id,
676+
)
677+
self.last_modified_time = datetime.datetime.now().timestamp()
678+
669679
return LocalPipelineExecutor(execution, self.local_session).execute()
670680

671681

@@ -686,17 +696,18 @@ def __init__(
686696
self.pipeline_execution_display_name = PipelineExecutionDisplayName
687697
self.status = _LocalExecutionStatus.EXECUTING.value
688698
self.failure_reason = None
689-
self.creation_time = datetime.datetime.now()
699+
self.creation_time = datetime.datetime.now().timestamp()
700+
self.last_modified_time = self.creation_time
690701
self.step_execution = {}
691702
self._initialize_step_execution(self.pipeline.steps)
692703
self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters)
693-
self.blockout_steps = {}
704+
self._blocked_steps = {}
694705

695706
def describe(self):
696707
"""Describe Pipeline Execution."""
697708
response = {
698709
"CreationTime": self.creation_time,
699-
"LastModifiedTime": self.creation_time,
710+
"LastModifiedTime": self.last_modified_time,
700711
"FailureReason": self.failure_reason,
701712
"PipelineArn": self.pipeline.name,
702713
"PipelineExecutionArn": self.pipeline_execution_name,
@@ -720,23 +731,33 @@ def list_steps(self):
720731
def update_execution_success(self):
721732
"""Mark execution as succeeded."""
722733
self.status = _LocalExecutionStatus.SUCCEEDED.value
734+
self.last_modified_time = datetime.datetime.now().timestamp()
735+
logger.info("Pipeline execution %s SUCCEEDED", self.pipeline_execution_name)
723736

724737
def update_execution_failure(self, step_name, failure_message):
725738
"""Mark execution as failed."""
726739
self.status = _LocalExecutionStatus.FAILED.value
727740
self.failure_reason = f"Step {step_name} failed with message: {failure_message}"
728-
logger.error("Pipeline execution failed because step %s failed.", step_name)
741+
self.last_modified_time = datetime.datetime.now().timestamp()
742+
logger.info(
743+
"Pipeline execution %s FAILED because step %s failed.",
744+
self.pipeline_execution_name,
745+
step_name,
746+
)
729747

730748
def update_step_properties(self, step_name, step_properties):
731749
"""Update pipeline step execution output properties."""
732750
self.step_execution.get(step_name).update_step_properties(step_properties)
751+
logger.info("Pipeline step %s SUCCEEDED.", step_name)
733752

734753
def update_step_failure(self, step_name, failure_message):
735754
"""Mark step_name as failed."""
736755
self.step_execution.get(step_name).update_step_failure(failure_message)
756+
logger.info("Pipeline step %s FAILED. Failure message is: %s", step_name, failure_message)
737757

738758
def mark_step_executing(self, step_name):
739759
"""Update pipelines step's status to EXECUTING and start_time to now."""
760+
logger.info("Starting pipeline step: %s", step_name)
740761
self.step_execution.get(step_name).mark_step_executing()
741762

742763
def _initialize_step_execution(self, steps):
@@ -749,6 +770,7 @@ def _initialize_step_execution(self, steps):
749770
StepTypeEnum.TRANSFORM,
750771
StepTypeEnum.CONDITION,
751772
StepTypeEnum.FAIL,
773+
StepTypeEnum.CREATE_MODEL,
752774
)
753775

754776
for step in steps:
@@ -828,29 +850,28 @@ def __init__(
828850
StepTypeEnum.TRAINING: self._construct_training_metadata,
829851
StepTypeEnum.PROCESSING: self._construct_processing_metadata,
830852
StepTypeEnum.TRANSFORM: self._construct_transform_metadata,
853+
StepTypeEnum.CREATE_MODEL: self._construct_model_metadata,
831854
StepTypeEnum.CONDITION: self._construct_condition_metadata,
832855
StepTypeEnum.FAIL: self._construct_fail_metadata,
833856
}
834857

835858
def update_step_properties(self, properties):
836859
"""Update pipeline step execution output properties."""
837-
logger.info("Successfully completed step %s.", self.name)
838860
self.properties = deepcopy(properties)
839861
self.status = _LocalExecutionStatus.SUCCEEDED.value
840-
self.end_time = datetime.datetime.now()
862+
self.end_time = datetime.datetime.now().timestamp()
841863

842864
def update_step_failure(self, failure_message):
843865
"""Update pipeline step execution failure status and message."""
844-
logger.error(failure_message)
845866
self.failure_reason = failure_message
846867
self.status = _LocalExecutionStatus.FAILED.value
847-
self.end_time = datetime.datetime.now()
868+
self.end_time = datetime.datetime.now().timestamp()
848869
raise StepExecutionException(self.name, failure_message)
849870

850871
def mark_step_executing(self):
851872
"""Update pipelines step's status to EXECUTING and start_time to now"""
852873
self.status = _LocalExecutionStatus.EXECUTING.value
853-
self.start_time = datetime.datetime.now()
874+
self.start_time = datetime.datetime.now().timestamp()
854875

855876
def to_list_steps_response(self):
856877
"""Convert to response dict for list_steps calls."""
@@ -875,23 +896,27 @@ def _construct_metadata(self):
875896

876897
def _construct_training_metadata(self):
877898
"""Construct training job metadata response."""
878-
return {"TrainingJob": {"Arn": self.properties.TrainingJobArn}}
899+
return {"TrainingJob": {"Arn": self.properties["TrainingJobName"]}}
879900

880901
def _construct_processing_metadata(self):
881902
"""Construct processing job metadata response."""
882-
return {"ProcessingJob": {"Arn": self.properties.ProcessingJobArn}}
903+
return {"ProcessingJob": {"Arn": self.properties["ProcessingJobName"]}}
883904

884905
def _construct_transform_metadata(self):
885906
"""Construct transform job metadata response."""
886-
return {"TransformJob": {"Arn": self.properties.TransformJobArn}}
907+
return {"TransformJob": {"Arn": self.properties["TransformJobName"]}}
908+
909+
def _construct_model_metadata(self):
910+
"""Construct create model step metadata response."""
911+
return {"Model": {"Arn": self.properties["ModelName"]}}
887912

888913
def _construct_condition_metadata(self):
889914
"""Construct condition step metadata response."""
890-
return {"Condition": {"Outcome": self.properties.Outcome}}
915+
return {"Condition": {"Outcome": self.properties["Outcome"]}}
891916

892917
def _construct_fail_metadata(self):
893918
"""Construct fail step metadata response."""
894-
return {"Fail": {"ErrorMessage": self.properties.ErrorMessage}}
919+
return {"Fail": {"ErrorMessage": self.properties["ErrorMessage"]}}
895920

896921

897922
class _LocalExecutionStatus(enum.Enum):

src/sagemaker/local/local_session.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,9 @@ def update_pipeline(
448448
}
449449
raise ClientError(error_response, "update_pipeline")
450450
LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description
451-
LocalSagemakerClient._pipelines[pipeline.name].last_modified_time = datetime.now()
451+
LocalSagemakerClient._pipelines[
452+
pipeline.name
453+
].last_modified_time = datetime.now().timestamp()
452454
return {"PipelineArn": pipeline.name}
453455

454456
def describe_pipeline(self, PipelineName):
@@ -715,17 +717,3 @@ def __init__(self, fileUri, content_type=None):
715717

716718
if content_type is not None:
717719
self.config["ContentType"] = content_type
718-
719-
720-
class LocalPipelineSession(LocalSession):
721-
"""Class representing a local session for SageMaker Pipelines executions."""
722-
723-
def __init__(
724-
self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False
725-
):
726-
super().__init__(
727-
boto_session=boto_session,
728-
default_bucket=default_bucket,
729-
s3_endpoint_url=s3_endpoint_url,
730-
disable_local_code=disable_local_code,
731-
)

0 commit comments

Comments
 (0)