1414from __future__ import absolute_import
1515
1616import abc
17-
1817from enum import Enum
1918from typing import Dict , List , Union
2019
2120import attr
2221
2322from sagemaker .estimator import EstimatorBase , _TrainingJob
24- from sagemaker .inputs import CreateModelInput , TrainingInput , TransformInput , FileSystemInput
23+ from sagemaker .inputs import (
24+ CompilationInput ,
25+ CreateModelInput ,
26+ FileSystemInput ,
27+ TrainingInput ,
28+ TransformInput ,
29+ )
2530from sagemaker .model import Model
2631from sagemaker .processing import (
2732 ProcessingInput ,
3136)
3237from sagemaker .transformer import Transformer , _TransformJob
3338from sagemaker .tuner import HyperparameterTuner , _TuningJob
34- from sagemaker .workflow .entities import (
35- DefaultEnumMeta ,
36- Entity ,
37- RequestType ,
38- )
39- from sagemaker .workflow .properties import (
40- PropertyFile ,
41- Properties ,
42- )
39+ from sagemaker .workflow .entities import DefaultEnumMeta , Entity , RequestType
4340from sagemaker .workflow .functions import Join
41+ from sagemaker .workflow .properties import Properties , PropertyFile
4442from sagemaker .workflow .retry import RetryPolicy
4543
4644
@@ -55,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5553 TRANSFORM = "Transform"
5654 CALLBACK = "Callback"
5755 TUNING = "Tuning"
56+ COMPILATION = "Compilation"
5857 LAMBDA = "Lambda"
5958
6059
@@ -681,3 +680,81 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
681680 "output/model.tar.gz" ,
682681 ],
683682 )
683+
684+
685+ class CompilationStep (ConfigurableRetryStep ):
686+ """Compilation step for workflow."""
687+
688+ def __init__ (
689+ self ,
690+ name : str ,
691+ estimator : EstimatorBase ,
692+ model : Model ,
693+ inputs : CompilationInput = None ,
694+ job_arguments : List [str ] = None ,
695+ depends_on : Union [List [str ], List [Step ]] = None ,
696+ retry_policies : List [RetryPolicy ] = None ,
697+ display_name : str = None ,
698+ description : str = None ,
699+ cache_config : CacheConfig = None ,
700+ ):
701+ """Construct a CompilationStep.
702+
703+ Given an `EstimatorBase` and a `sagemaker.model.Model` instance construct a CompilationStep.
704+
705+ In addition to the estimator and Model instances, the other arguments are those that are
706+ supplied to the `compile_model` method of the `sagemaker.model.Model.compile_model`.
707+
708+ Args:
709+ name (str): The name of the compilation step.
710+ estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
711+ model (Model): A `sagemaker.model.Model` instance.
712+ inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
713+ Defaults to `None`.
714+ job_arguments (List[str]): A list of strings to be passed into the processing job.
715+ Defaults to `None`.
716+ depends_on (List[str] or List[Step]): A list of step names or step instances
717+ this `sagemaker.workflow.steps.CompilationStep` depends on
718+ retry_policies (List[RetryPolicy]): A list of retry policy
719+ display_name (str): The display name of the compilation step.
720+ description (str): The description of the compilation step.
721+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
722+ """
723+ super (CompilationStep , self ).__init__ (
724+ name , StepTypeEnum .COMPILATION , display_name , description , depends_on , retry_policies
725+ )
726+ self .estimator = estimator
727+ self .model = model
728+ self .inputs = inputs
729+ self .job_arguments = job_arguments
730+ self ._properties = Properties (
731+ path = f"Steps.{ name } " , shape_name = "DescribeCompilationJobResponse"
732+ )
733+ self .cache_config = cache_config
734+
735+ @property
736+ def arguments (self ) -> RequestType :
737+ """The arguments dict that is used to call `create_compilation_job`.
738+
739+ NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
740+ The TrainingJobName and ExperimentConfig attributes cannot be included.
741+ """
742+
743+ compilation_args = self .model ._get_compilation_args (self .estimator , self .inputs )
744+ request_dict = self .model .sagemaker_session ._get_compilation_request (** compilation_args )
745+ request_dict .pop ("CompilationJobName" )
746+
747+ return request_dict
748+
749+ @property
750+ def properties (self ):
751+ """A Properties object representing the DescribeTrainingJobResponse data model."""
752+ return self ._properties
753+
754+ def to_request (self ) -> RequestType :
755+ """Updates the dictionary with cache configuration."""
756+ request_dict = super ().to_request ()
757+ if self .cache_config :
758+ request_dict .update (self .cache_config .config )
759+
760+ return request_dict
0 commit comments