1616import logging
1717from typing import Union , List , Dict , Optional
1818
19- from sagemaker import Model , PipelineModel
19+ from sagemaker import Model , PipelineModel , Session
2020from sagemaker .workflow ._utils import _RegisterModelStep , _RepackModelStep
2121from sagemaker .workflow .pipeline_context import PipelineSession , _ModelStepArguments
22- from sagemaker .workflow .retry import RetryPolicy
22+ from sagemaker .workflow .retry import RetryPolicy , SageMakerJobStepRetryPolicy
2323from sagemaker .workflow .step_collections import StepCollection
2424from sagemaker .workflow .steps import Step , CreateModelStep
2525
@@ -57,17 +57,72 @@ def __init__(
5757 If a listed `Step` name does not exist, an error is returned (default: None).
5858 retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
5959 policies for the `ModelStep` (default: None).
60+
61+ If a list of retry policies is provided, it would be applied to all steps in the
62+ `ModelStep` collection. Note: in this case, `SageMakerJobStepRetryPolicy`
63+ is not allowed, since create/register model step does not support it.
64+ Please find the example below:
65+
66+ .. code:: python
67+
68+ ModelStep(
69+ ...
70+ retry_policies=[
71+ StepRetryPolicy(...),
72+ ],
73+ )
74+
75+ If a dict is provided, it can specify different retry policies for different
76+ types of steps in the `ModelStep` collection. Similarly,
77+ `SageMakerJobStepRetryPolicy` is not allowed for create/register model step.
78+ See examples below:
79+
80+ .. code:: python
81+
82+ ModelStep(
83+ ...
84+ retry_policies=dict(
85+ register_model_retry_policies=[
86+ StepRetryPolicy(...),
87+ ],
88+ repack_model_retry_policies=[
89+ SageMakerJobStepRetryPolicy(...),
90+ ],
91+ )
92+ )
93+
94+ or
95+
96+ .. code:: python
97+
98+ ModelStep(
99+ ...
100+ retry_policies=dict(
101+ create_model_retry_policies=[
102+ StepRetryPolicy(...),
103+ ],
104+ repack_model_retry_policies=[
105+ SageMakerJobStepRetryPolicy(...),
106+ ],
107+ )
108+ )
109+
60110 display_name (str): The display name of the `ModelStep`.
61111 The display name provides better UI readability. (default: None).
62112 description (str): The description of the `ModelStep` (default: None).
63113 """
64114 # TODO: add a doc link in error message once ready
65- if not isinstance (step_args , _ModelStepArguments ):
66- raise TypeError (
67- "To correctly configure a ModelStep, "
68- "the step_args must be a `_ModelStepArguments` object generated by "
69- ".create() or .register()."
70- )
115+ from sagemaker .workflow .utilities import validate_step_args_input
116+
117+ validate_step_args_input (
118+ step_args = step_args ,
119+ expected_caller = {
120+ Session .create_model .__name__ ,
121+ Session .create_model_package_from_containers .__name__ ,
122+ },
123+ error_message = "The step_args of ModelStep must be obtained from model.create() "
124+ "or model.register()." ,
125+ )
71126 if not (step_args .create_model_request is None ) ^ (
72127 step_args .create_model_package_request is None
73128 ):
@@ -93,7 +148,22 @@ def __init__(
93148 self ._create_model_args = self .step_args .create_model_request
94149 self ._register_model_args = self .step_args .create_model_package_request
95150 self ._need_runtime_repack = self .step_args .need_runtime_repack
151+ self ._assign_and_validate_retry_policies (retry_policies )
152+
153+ if self ._need_runtime_repack :
154+ self ._append_repack_model_step ()
155+ if self ._register_model_args :
156+ self ._append_register_model_step ()
157+ else :
158+ self ._append_create_model_step ()
96159
160+ def _assign_and_validate_retry_policies (self , retry_policies ):
161+ """Assign and validate retry policies according to each kind of sub steps
162+
163+ Args:
164+ retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
165+ policies for the `ModelStep`.
166+ """
97167 if isinstance (retry_policies , dict ):
98168 self ._create_model_retry_policies = retry_policies .get (
99169 _CREATE_MODEL_RETRY_POLICIES , None
@@ -109,12 +179,23 @@ def __init__(
109179 self ._register_model_retry_policies = retry_policies
110180 self ._repack_model_retry_policies = retry_policies
111181
112- if self ._need_runtime_repack :
113- self ._append_repack_model_step ()
114- if self ._register_model_args :
115- self ._append_register_model_step ()
116- else :
117- self ._append_create_model_step ()
182+ self ._validate_sagemaker_job_step_retry_policy ()
183+
184+ def _validate_sagemaker_job_step_retry_policy (self ):
185+ """Validate SageMakerJobStepRetryPolicy
186+
187+ Validate that SageMakerJobStepRetryPolicy is not assigning to create/register model step.
188+ """
189+ retry_policies = set (
190+ (self ._create_model_retry_policies or []) + (self ._register_model_retry_policies or [])
191+ )
192+ for policy in retry_policies :
193+ if not isinstance (policy , SageMakerJobStepRetryPolicy ):
194+ continue
195+ raise ValueError (
196+ "SageMakerJobStepRetryPolicy is not allowed for a create/register"
197+ " model step. Please use StepRetryPolicy instead"
198+ )
118199
119200 def _append_register_model_step (self ):
120201 """Create and append a `_RegisterModelStep`"""
0 commit comments