1717"""
1818from __future__ import absolute_import
1919
20+ import abc
21+
2022from enum import Enum
2123from typing import Dict , List , Union
2224
3335from sagemaker .workflow .execution_variables import ExecutionVariable
3436from sagemaker .workflow .parameters import Parameter
3537from sagemaker .workflow .properties import Properties
38+ from sagemaker .workflow .entities import PipelineVariable
3639
3740# TODO: consider base class for those with an expr method, rather than defining a type here
3841ConditionValueType = Union [ExecutionVariable , Parameter , Properties ]
@@ -61,6 +64,11 @@ class Condition(Entity):
6164
6265 condition_type : ConditionTypeEnum = attr .ib (factory = ConditionTypeEnum .factory )
6366
67+ @property
68+ @abc .abstractmethod
69+ def _referenced_steps (self ) -> List [str ]:
70+ """List of step names that this function depends on."""
71+
6472
6573@attr .s
6674class ConditionComparison (Condition ):
@@ -84,6 +92,16 @@ def to_request(self) -> RequestType:
8492 "RightValue" : primitive_or_expr (self .right ),
8593 }
8694
95+ @property
96+ def _referenced_steps (self ) -> List [str ]:
97+ """List of step names that this function depends on."""
98+ steps = []
99+ if isinstance (self .left , PipelineVariable ):
100+ steps .extend (self .left ._referenced_steps )
101+ if isinstance (self .right , PipelineVariable ):
102+ steps .extend (self .right ._referenced_steps )
103+ return steps
104+
87105
88106class ConditionEquals (ConditionComparison ):
89107 """A condition for equality comparisons."""
@@ -213,6 +231,17 @@ def to_request(self) -> RequestType:
213231 "Values" : [primitive_or_expr (in_value ) for in_value in self .in_values ],
214232 }
215233
234+ @property
235+ def _referenced_steps (self ) -> List [str ]:
236+ """List of step names that this function depends on."""
237+ steps = []
238+ if isinstance (self .value , PipelineVariable ):
239+ steps .extend (self .value ._referenced_steps )
240+ for in_value in self .in_values :
241+ if isinstance (in_value , PipelineVariable ):
242+ steps .extend (in_value ._referenced_steps )
243+ return steps
244+
216245
217246class ConditionNot (Condition ):
218247 """A condition for negating another `Condition`."""
@@ -230,6 +259,11 @@ def to_request(self) -> RequestType:
230259 """Get the request structure for workflow service calls."""
231260 return {"Type" : self .condition_type .value , "Expression" : self .expression .to_request ()}
232261
262+ @property
263+ def _referenced_steps (self ) -> List [str ]:
264+ """List of step names that this function depends on."""
265+ return self .expression ._referenced_steps
266+
233267
234268class ConditionOr (Condition ):
235269 """A condition for taking the logical OR of a list of `Condition` instances."""
@@ -250,6 +284,14 @@ def to_request(self) -> RequestType:
250284 "Conditions" : [condition .to_request () for condition in self .conditions ],
251285 }
252286
287+ @property
288+ def _referenced_steps (self ) -> List [str ]:
289+ """List of step names that this function depends on."""
290+ steps = []
291+ for condition in self .conditions :
292+ steps .extend (condition ._referenced_steps )
293+ return steps
294+
253295
254296def primitive_or_expr (
255297 value : Union [ExecutionVariable , Expression , PrimitiveType , Parameter , Properties ]
0 commit comments