1515
1616import abc
1717import warnings
18+
1819from enum import Enum
1920from typing import Dict , List , Union
2021from urllib .parse import urlparse
2122
2223import attr
2324
2425from sagemaker .estimator import EstimatorBase , _TrainingJob
25- from sagemaker .inputs import (
26- CompilationInput ,
27- CreateModelInput ,
28- FileSystemInput ,
29- TrainingInput ,
30- TransformInput ,
31- )
26+ from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
3227from sagemaker .model import Model
3328from sagemaker .pipeline import PipelineModel
3429from sagemaker .processing import (
3934)
4035from sagemaker .transformer import Transformer , _TransformJob
4136from sagemaker .tuner import HyperparameterTuner , _TuningJob
42- from sagemaker .workflow .entities import DefaultEnumMeta , Entity , RequestType
37+ from sagemaker .workflow .entities import (
38+ DefaultEnumMeta ,
39+ Entity ,
40+ RequestType ,
41+ )
42+ from sagemaker .workflow .properties import (
43+ PropertyFile ,
44+ Properties ,
45+ )
4346from sagemaker .workflow .functions import Join
44- from sagemaker .workflow .properties import Properties , PropertyFile
4547from sagemaker .workflow .retry import RetryPolicy
4648
4749
@@ -56,7 +58,6 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5658 TRANSFORM = "Transform"
5759 CALLBACK = "Callback"
5860 TUNING = "Tuning"
59- COMPILATION = "Compilation"
6061 LAMBDA = "Lambda"
6162 QUALITY_CHECK = "QualityCheck"
6263 CLARIFY_CHECK = "ClarifyCheck"
@@ -730,81 +731,3 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
730731 "output/model.tar.gz" ,
731732 ],
732733 )
733-
734-
735- class CompilationStep (ConfigurableRetryStep ):
736- """Compilation step for workflow."""
737-
738- def __init__ (
739- self ,
740- name : str ,
741- estimator : EstimatorBase ,
742- model : Model ,
743- inputs : CompilationInput = None ,
744- job_arguments : List [str ] = None ,
745- depends_on : Union [List [str ], List [Step ]] = None ,
746- retry_policies : List [RetryPolicy ] = None ,
747- display_name : str = None ,
748- description : str = None ,
749- cache_config : CacheConfig = None ,
750- ):
751- """Construct a CompilationStep.
752-
753- Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
754-
755- In addition to the estimator and Model instances, the other arguments are those that are
756- supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
757-
758- Args:
759- name (str): The name of the compilation step.
760- estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
761- model (Model): A `sagemaker.model.Model` instance.
762- inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
763- Defaults to `None`.
764- job_arguments (List[str]): A list of strings to be passed into the processing job.
765- Defaults to `None`.
766- depends_on (List[str] or List[Step]): A list of step names or step instances
767- this `sagemaker.workflow.steps.CompilationStep` depends on
768- retry_policies (List[RetryPolicy]): A list of retry policy
769- display_name (str): The display name of the compilation step.
770- description (str): The description of the compilation step.
771- cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
772- """
773- super (CompilationStep , self ).__init__ (
774- name , StepTypeEnum .COMPILATION , display_name , description , depends_on , retry_policies
775- )
776- self .estimator = estimator
777- self .model = model
778- self .inputs = inputs
779- self .job_arguments = job_arguments
780- self ._properties = Properties (
781- path = f"Steps.{ name } " , shape_name = "DescribeCompilationJobResponse"
782- )
783- self .cache_config = cache_config
784-
785- @property
786- def arguments (self ) -> RequestType :
787- """The arguments dict that is used to call `create_compilation_job`.
788-
789- NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
790- The TrainingJobName and ExperimentConfig attributes cannot be included.
791- """
792-
793- compilation_args = self .model ._get_compilation_args (self .estimator , self .inputs )
794- request_dict = self .model .sagemaker_session ._get_compilation_request (** compilation_args )
795- request_dict .pop ("CompilationJobName" )
796-
797- return request_dict
798-
799- @property
800- def properties (self ):
801- """A Properties object representing the DescribeTrainingJobResponse data model."""
802- return self ._properties
803-
804- def to_request (self ) -> RequestType :
805- """Updates the dictionary with cache configuration."""
806- request_dict = super ().to_request ()
807- if self .cache_config :
808- request_dict .update (self .cache_config .config )
809-
810- return request_dict
0 commit comments