1717import shutil
1818import tarfile
1919import tempfile
20- from typing import List , Union , Optional
20+ from typing import List , Union , Optional , TYPE_CHECKING
2121from sagemaker import image_uris
2222from sagemaker .inputs import TrainingInput
2323from sagemaker .estimator import EstimatorBase
3434from sagemaker .utils import _save_model , download_file_from_url
3535from sagemaker .workflow .retry import RetryPolicy
3636
37+ if TYPE_CHECKING :
38+ from sagemaker .workflow .step_collections import StepCollection
39+
3740FRAMEWORK_VERSION = "0.23-1"
3841INSTANCE_TYPE = "ml.m5.large"
3942REPACK_SCRIPT = "_repack_model.py"
@@ -57,7 +60,7 @@ def __init__(
5760 description : str = None ,
5861 source_dir : str = None ,
5962 dependencies : List = None ,
60- depends_on : Union [List [str ], List [ Step ]] = None ,
63+ depends_on : Optional [List [Union [ str , Step , "StepCollection" ] ]] = None ,
6164 retry_policies : List [RetryPolicy ] = None ,
6265 subnets = None ,
6366 security_group_ids = None ,
@@ -124,8 +127,9 @@ def __init__(
124127 >>> |------ virtual-env
125128
126129 This is not supported with "local code" in Local Mode.
127- depends_on (List[str] or List[Step]): A list of step names or instances
128- this step depends on (default: None).
130+ depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
131+ names or `Step` instances or `StepCollection` instances that the current `Step`
132+ depends on (default: None).
129133 retry_policies (List[RetryPolicy]): The list of retry policies for the current step
130134 (default: None).
131135 subnets (list[str]): List of subnet ids. If not specified, the re-packing
@@ -274,7 +278,7 @@ def __init__(
274278 compile_model_family = None ,
275279 display_name : str = None ,
276280 description = None ,
277- depends_on : Optional [Union [ List [str ], List [ Step ]]] = None ,
281+ depends_on : Optional [List [Union [ str , Step , "StepCollection" ]]] = None ,
278282 retry_policies : Optional [List [RetryPolicy ]] = None ,
279283 tags = None ,
280284 container_def_list = None ,
@@ -311,8 +315,9 @@ def __init__(
311315 if specified, a compiled model will be used (default: None).
312316 display_name (str): The display name of this `_RegisterModelStep` step (default: None).
313317 description (str): Model Package description (default: None).
314- depends_on (List[str] or List[Step]): A list of step names or instances
315- this step depends on (default: None).
318+ depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
319+ names or `Step` instances or `StepCollection` instances that the current `Step`
320+ depends on (default: None).
316321 retry_policies (List[RetryPolicy]): The list of retry policies for the current step
317322 (default: None).
318323 tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to
0 commit comments