@@ -60,12 +60,13 @@ class Step(Entity):
6060 Attributes:
6161 name (str): The name of the step.
6262 step_type (StepTypeEnum): The type of the step.
63- depends_on (List[str]): The list of step names the current step depends on
63+ depends_on (List[str] or List[Step]): The list of step names or step
64+ instances the current step depends on
6465 """
6566
6667 name : str = attr .ib (factory = str )
6768 step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
68- depends_on : List [str ] = attr .ib (default = None )
69+ depends_on : Union [ List [str ], List [ "Step" ] ] = attr .ib (default = None )
6970
7071 @property
7172 @abc .abstractmethod
@@ -85,11 +86,13 @@ def to_request(self) -> RequestType:
8586 "Arguments" : self .arguments ,
8687 }
8788 if self .depends_on :
88- request_dict ["DependsOn" ] = self .depends_on
89+ request_dict ["DependsOn" ] = self ._resolve_depends_on (self .depends_on )
90+
8991 return request_dict
9092
91- def add_depends_on (self , step_names : List [str ]):
92- """Add step names to the current step depends on list"""
93+ def add_depends_on (self , step_names : Union [List [str ], List ["Step" ]]):
94+ """Add step names or step instances to the current step depends on list"""
95+
9396 if not step_names :
9497 return
9598 if not self .depends_on :
@@ -101,6 +104,19 @@ def ref(self) -> Dict[str, str]:
101104 """Gets a reference dict for steps"""
102105 return {"Name" : self .name }
103106
107+ @staticmethod
108+ def _resolve_depends_on (depends_on_list : Union [List [str ], List ["Step" ]]):
109+ """Resolver the step depends on list"""
110+ depends_on = []
111+ for step in depends_on_list :
112+ if isinstance (step , Step ):
113+ depends_on .append (step .name )
114+ elif isinstance (step , str ):
115+ depends_on .append (step )
116+ else :
117+ raise ValueError (f"Invalid input step name: { step } " )
118+ return depends_on
119+
104120
105121@attr .s
106122class CacheConfig :
@@ -143,7 +159,7 @@ def __init__(
143159 estimator : EstimatorBase ,
144160 inputs : Union [TrainingInput , dict , str , FileSystemInput ] = None ,
145161 cache_config : CacheConfig = None ,
146- depends_on : List [str ] = None ,
162+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
147163 ):
148164 """Construct a TrainingStep, given an `EstimatorBase` instance.
149165
@@ -171,8 +187,8 @@ def __init__(
171187 the path to the training dataset.
172188
173189 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
174- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
175- depends on
190+ depends_on (List[str] or List[Step] ): A list of step names or step instances
191+ this `sagemaker.workflow.steps.TrainingStep` depends on
176192 """
177193 super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
178194 self .estimator = estimator
@@ -217,7 +233,11 @@ class CreateModelStep(Step):
217233 """CreateModel step for workflow."""
218234
219235 def __init__ (
220- self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
236+ self ,
237+ name : str ,
238+ model : Model ,
239+ inputs : CreateModelInput ,
240+ depends_on : Union [List [str ], List [Step ]] = None ,
221241 ):
222242 """Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
223243
@@ -229,8 +249,8 @@ def __init__(
229249 model (Model): A `sagemaker.model.Model` instance.
230250 inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
231251 Defaults to `None`.
232- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
233- depends on
252+ depends_on (List[str] or List[Step] ): A list of step names or step instances
253+ this `sagemaker.workflow.steps.CreateModelStep` depends on
234254 """
235255 super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
236256 self .model = model
@@ -275,7 +295,7 @@ def __init__(
275295 transformer : Transformer ,
276296 inputs : TransformInput ,
277297 cache_config : CacheConfig = None ,
278- depends_on : List [str ] = None ,
298+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
279299 ):
280300 """Constructs a TransformStep, given an `Transformer` instance.
281301
@@ -287,8 +307,8 @@ def __init__(
287307 transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
288308 inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
289309 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
290- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
291- depends on
310+ depends_on (List[str] or List[Step] ): A list of step names or step instances
311+ this `sagemaker.workflow.steps.TransformStep` depends on
292312 """
293313 super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
294314 self .transformer = transformer
@@ -351,7 +371,7 @@ def __init__(
351371 code : str = None ,
352372 property_files : List [PropertyFile ] = None ,
353373 cache_config : CacheConfig = None ,
354- depends_on : List [str ] = None ,
374+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
355375 ):
356376 """Construct a ProcessingStep, given a `Processor` instance.
357377
@@ -372,8 +392,8 @@ def __init__(
372392 property_files (List[PropertyFile]): A list of property files that workflow looks
373393 for and resolves from the configured processing output list.
374394 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
375- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
376- depends on
395+ depends_on (List[str] or List[Step] ): A list of step names or step instance
396+ this `sagemaker.workflow.steps.ProcessingStep` depends on
377397 """
378398 super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
379399 self .processor = processor
0 commit comments