@@ -64,12 +64,13 @@ class Step(Entity):
6464 Attributes:
6565 name (str): The name of the step.
6666 step_type (StepTypeEnum): The type of the step.
67- depends_on (List[str]): The list of step names the current step depends on
67+ depends_on (List[str] or List[Step]): The list of step names or step
68+ instances the current step depends on
6869 """
6970
7071 name : str = attr .ib (factory = str )
7172 step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
72- depends_on : List [str ] = attr .ib (default = None )
73+ depends_on : Union [ List [str ], List [ "Step" ] ] = attr .ib (default = None )
7374
7475 @property
7576 @abc .abstractmethod
@@ -89,11 +90,13 @@ def to_request(self) -> RequestType:
8990 "Arguments" : self .arguments ,
9091 }
9192 if self .depends_on :
92- request_dict ["DependsOn" ] = self .depends_on
93+ request_dict ["DependsOn" ] = self ._resolve_depends_on (self .depends_on )
94+
9395 return request_dict
9496
95- def add_depends_on (self , step_names : List [str ]):
96- """Add step names to the current step depends on list"""
97+ def add_depends_on (self , step_names : Union [List [str ], List ["Step" ]]):
98+ """Add step names or step instances to the current step depends on list"""
99+
97100 if not step_names :
98101 return
99102
@@ -106,6 +109,19 @@ def ref(self) -> Dict[str, str]:
106109 """Gets a reference dict for steps"""
107110 return {"Name" : self .name }
108111
112+ @staticmethod
113+ def _resolve_depends_on (depends_on_list : Union [List [str ], List ["Step" ]]):
114+ """Resolver the step depends on list"""
115+ depends_on = []
116+ for step in depends_on_list :
117+ if isinstance (step , Step ):
118+ depends_on .append (step .name )
119+ elif isinstance (step , str ):
120+ depends_on .append (step )
121+ else :
122+ raise ValueError (f"Invalid input step name: { step } " )
123+ return depends_on
124+
109125
110126@attr .s
111127class CacheConfig :
@@ -154,7 +170,7 @@ def __init__(
154170 estimator : EstimatorBase ,
155171 inputs : Union [TrainingInput , dict , str , FileSystemInput ] = None ,
156172 cache_config : CacheConfig = None ,
157- depends_on : List [str ] = None ,
173+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
158174 ):
159175 """Construct a TrainingStep, given an `EstimatorBase` instance.
160176
@@ -181,8 +197,8 @@ def __init__(
181197 the path to the training dataset.
182198
183199 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
184- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
185- depends on
200+ depends_on (List[str] or List[Step] ): A list of step names or step instances
201+ this `sagemaker.workflow.steps.TrainingStep` depends on
186202 """
187203 super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
188204 self .estimator = estimator
@@ -227,7 +243,11 @@ class CreateModelStep(Step):
227243 """CreateModel step for workflow."""
228244
229245 def __init__ (
230- self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
246+ self ,
247+ name : str ,
248+ model : Model ,
249+ inputs : CreateModelInput ,
250+ depends_on : Union [List [str ], List [Step ]] = None ,
231251 ):
232252 """Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
233253
@@ -239,8 +259,8 @@ def __init__(
239259 model (Model): A `sagemaker.model.Model` instance.
240260 inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
241261 Defaults to `None`.
242- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
243- depends on
262+ depends_on (List[str] or List[Step] ): A list of step names or step instances
263+ this `sagemaker.workflow.steps.CreateModelStep` depends on
244264 """
245265 super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
246266 self .model = model
@@ -285,7 +305,7 @@ def __init__(
285305 transformer : Transformer ,
286306 inputs : TransformInput ,
287307 cache_config : CacheConfig = None ,
288- depends_on : List [str ] = None ,
308+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
289309 ):
290310 """Constructs a TransformStep, given an `Transformer` instance.
291311
@@ -297,8 +317,8 @@ def __init__(
297317 transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
298318 inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
299319 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
300- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
301- depends on
320+ depends_on (List[str] or List[Step] ): A list of step names or step instances
321+ this `sagemaker.workflow.steps.TransformStep` depends on
302322 """
303323 super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
304324 self .transformer = transformer
@@ -361,7 +381,7 @@ def __init__(
361381 code : str = None ,
362382 property_files : List [PropertyFile ] = None ,
363383 cache_config : CacheConfig = None ,
364- depends_on : List [str ] = None ,
384+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
365385 ):
366386 """Construct a ProcessingStep, given a `Processor` instance.
367387
@@ -382,8 +402,8 @@ def __init__(
382402 property_files (List[PropertyFile]): A list of property files that workflow looks
383403 for and resolves from the configured processing output list.
384404 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
385- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
386- depends on
405+ depends_on (List[str] or List[Step] ): A list of step names or step instance
406+ this `sagemaker.workflow.steps.ProcessingStep` depends on
387407 """
388408 super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
389409 self .processor = processor
@@ -451,7 +471,7 @@ def __init__(
451471 inputs = None ,
452472 job_arguments : List [str ] = None ,
453473 cache_config : CacheConfig = None ,
454- depends_on : List [str ] = None ,
474+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
455475 ):
456476 """Construct a TuningStep, given a `HyperparameterTuner` instance.
457477
@@ -491,8 +511,8 @@ def __init__(
491511 job_arguments (List[str]): A list of strings to be passed into the processing job.
492512 Defaults to `None`.
493513 cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
494- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
495- depends on
514+ depends_on (List[str] or List[Step] ): A list of step names or step instance
515+ this `sagemaker.workflow.steps.ProcessingStep` depends on
496516 """
497517 super (TuningStep , self ).__init__ (name , StepTypeEnum .TUNING , depends_on )
498518 self .tuner = tuner
@@ -545,7 +565,7 @@ def to_request(self) -> RequestType:
545565
546566 return request_dict
547567
548- def get_top_model_s3_uri (self , top_k : int , s3_bucket : str , prefix : str = "" ):
568+ def get_top_model_s3_uri (self , top_k : int , s3_bucket : str , prefix : str = "" ) -> Join :
549569 """Get the model artifact s3 uri from the top performing training jobs.
550570
551571 Args:
0 commit comments