@@ -92,6 +92,38 @@ def ref(self) -> Dict[str, str]:
9292 return {"Name" : self .name }
9393
9494
95+ @attr .s
96+ class CacheConfig :
97+ """Configuration class to enable caching in pipeline workflow.
98+
99+ If caching is enabled, the pipeline attempts to find a previous execution of a step
100+ that was called with the same arguments. Step caching only considers successful execution.
101+ If a successful previous execution is found, the pipeline propagates the values
102+ from previous execution rather than recomputing the step. When multiple successful executions
103+ exist within the timeout period, it uses the result for the most recent successful execution.
104+
105+
106+ Attributes:
107+ enable_caching (bool): To enable step caching. Defaults to `False`.
108+ expire_after (str): If step caching is enabled, a timeout also needs to defined.
109+ It defines how old a previous execution can be to be considered for reuse.
110+ Value should be an ISO 8601 duration string. Defaults to `None`.
111+ """
112+
113+ enable_caching : bool = attr .ib (default = False )
114+ expire_after = attr .ib (
115+ default = None , validator = attr .validators .optional (attr .validators .instance_of (str ))
116+ )
117+
118+ @property
119+ def config (self ):
120+ """Configures caching in pipeline steps."""
121+ config = {"Enabled" : self .enable_caching }
122+ if self .expire_after is not None :
123+ config ["ExpireAfter" ] = self .expire_after
124+ return {"CacheConfig" : config }
125+
126+
95127class TrainingStep (Step ):
96128 """Training step for workflow."""
97129
@@ -100,6 +132,7 @@ def __init__(
100132 name : str ,
101133 estimator : EstimatorBase ,
102134 inputs : TrainingInput = None ,
135+ cache_config : CacheConfig = None ,
103136 ):
104137 """Construct a TrainingStep, given an `EstimatorBase` instance.
105138
@@ -110,14 +143,15 @@ def __init__(
110143 name (str): The name of the training step.
111144 estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
112145 inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
113147 """
114148 super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING )
115149 self .estimator = estimator
116150 self .inputs = inputs
117-
118151 self ._properties = Properties (
119152 path = f"Steps.{ name } " , shape_name = "DescribeTrainingJobResponse"
120153 )
154+ self .cache_config = cache_config
121155
122156 @property
123157 def arguments (self ) -> RequestType :
@@ -144,6 +178,14 @@ def properties(self):
144178 """A Properties object representing the DescribeTrainingJobResponse data model."""
145179 return self ._properties
146180
181+ def to_request (self ) -> RequestType :
182+ """Updates the dictionary with cache configuration."""
183+ request_dict = super ().to_request ()
184+ if self .cache_config :
185+ request_dict .update (self .cache_config .config )
186+
187+ return request_dict
188+
147189
148190class CreateModelStep (Step ):
149191 """CreateModel step for workflow."""
@@ -207,6 +249,7 @@ def __init__(
207249 name : str ,
208250 transformer : Transformer ,
209251 inputs : TransformInput ,
252+ cache_config : CacheConfig = None ,
210253 ):
211254 """Constructs a TransformStep, given an `Transformer` instance.
212255
@@ -217,11 +260,12 @@ def __init__(
217260 name (str): The name of the transform step.
218261 transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
219262 inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
263+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
220264 """
221265 super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM )
222266 self .transformer = transformer
223267 self .inputs = inputs
224-
268+ self . cache_config = cache_config
225269 self ._properties = Properties (
226270 path = f"Steps.{ name } " , shape_name = "DescribeTransformJobResponse"
227271 )
@@ -257,6 +301,14 @@ def properties(self):
257301 """A Properties object representing the DescribeTransformJobResponse data model."""
258302 return self ._properties
259303
304+ def to_request (self ) -> RequestType :
305+ """Updates the dictionary with cache configuration."""
306+ request_dict = super ().to_request ()
307+ if self .cache_config :
308+ request_dict .update (self .cache_config .config )
309+
310+ return request_dict
311+
260312
261313class ProcessingStep (Step ):
262314 """Processing step for workflow."""
@@ -270,6 +322,7 @@ def __init__(
270322 job_arguments : List [str ] = None ,
271323 code : str = None ,
272324 property_files : List [PropertyFile ] = None ,
325+ cache_config : CacheConfig = None ,
273326 ):
274327 """Construct a ProcessingStep, given a `Processor` instance.
275328
@@ -289,6 +342,7 @@ def __init__(
289342 script to run. Defaults to `None`.
290343 property_files (List[PropertyFile]): A list of property files that workflow looks
291344 for and resolves from the configured processing output list.
345+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
292346 """
293347 super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING )
294348 self .processor = processor
@@ -305,6 +359,7 @@ def __init__(
305359 self ._properties = Properties (
306360 path = f"Steps.{ name } " , shape_name = "DescribeProcessingJobResponse"
307361 )
362+ self .cache_config = cache_config
308363
309364 @property
310365 def arguments (self ) -> RequestType :
@@ -335,6 +390,8 @@ def properties(self):
335390 def to_request (self ) -> RequestType :
336391 """Get the request structure for workflow service calls."""
337392 request_dict = super (ProcessingStep , self ).to_request ()
393+ if self .cache_config :
394+ request_dict .update (self .cache_config .config )
338395 if self .property_files :
339396 request_dict ["PropertyFiles" ] = [
340397 property_file .expr for property_file in self .property_files
0 commit comments