# Name suffixes for the sub-steps that `ModelStep` creates.
# `_REPACK_MODEL_NAME_BASE` is combined with the step name and the model name (or its
# index) to name each repack step; the register/create bases are presumably used the
# same way by code outside this view -- TODO confirm.
2929_REGISTER_MODEL_NAME_BASE = "RegisterModel"
3030_CREATE_MODEL_NAME_BASE = "CreateModel"
3131_REPACK_MODEL_NAME_BASE = "RepackModel"
# Keys that callers may not override through `repack_model_step_settings`. They are
# removed from the settings before the repack step is built, and a warning is logged
# for each one that was supplied with a truthy value.
32+ _IGNORED_REPACK_PARAM_LIST = ["entry_point" , "source_dir" , "hyperparameters" , "dependencies" ]
33+
# Module-level logger used for the warnings emitted by this module.
34+ logger = logging .getLogger (__name__ )
3235
3336
3437class ModelStep (StepCollection ):
@@ -42,6 +45,7 @@ def __init__(
4245 retry_policies : Optional [Union [List [RetryPolicy ], Dict [str , List [RetryPolicy ]]]] = None ,
4346 display_name : Optional [str ] = None ,
4447 description : Optional [str ] = None ,
48+ repack_model_step_settings : Optional [Dict [str , any ]] = None ,
4549 ):
4650 """Constructs a `ModelStep`.
4751
@@ -115,6 +119,15 @@ def __init__(
115119 display_name (str): The display name of the `ModelStep`.
116120 The display name provides better UI readability. (default: None).
117121 description (str): The description of the `ModelStep` (default: None).
122+ repack_model_step_settings (Dict[str, any]): The kwargs passed to the _RepackModelStep
123+ to customize the configuration of the underlying repack model job (default: None).
124+ Notes:
125+ 1. If the _RepackModelStep is unnecessary, the settings will be ignored.
126+ 2. If the _RepackModelStep is added, the repack_model_step_settings
127+ is honored if set.
128+ 3. In repack_model_step_settings, the arguments with misspelled keys will be
129+ ignored. Please refer to the expected parameters of repack model job in
130+ :class:`~sagemaker.sklearn.estimator.SKLearn` and its base classes.
118131 """
119132 from sagemaker .workflow .utilities import validate_step_args_input
120133
@@ -148,6 +161,9 @@ def __init__(
148161 self .display_name = display_name
149162 self .description = description
150163 self .steps : List [Step ] = []
164+ self ._repack_model_step_settings = (
165+ dict (repack_model_step_settings ) if repack_model_step_settings else {}
166+ )
151167 self ._model = step_args .model
152168 self ._create_model_args = self .step_args .create_model_request
153169 self ._register_model_args = self .step_args .create_model_package_request
@@ -157,6 +173,12 @@ def __init__(
157173
158174 if self ._need_runtime_repack :
159175 self ._append_repack_model_step ()
176+ elif self ._repack_model_step_settings :
177+ logger .warning (
178+ "Non-empty repack_model_step_settings is supplied but no repack model "
179+ "step is needed. Ignoring the repack_model_step_settings."
180+ )
181+
160182 if self ._register_model_args :
161183 self ._append_register_model_step ()
162184 else :
@@ -235,14 +257,12 @@ def _append_repack_model_step(self):
235257 elif isinstance (self ._model , Model ):
236258 model_list = [self ._model ]
237259 else :
238- logging .warning ("No models to repack" )
260+ logger .warning ("No models to repack" )
239261 return
240262
241- security_group_ids = None
242- subnets = None
243- if self ._model .vpc_config :
244- security_group_ids = self ._model .vpc_config .get ("SecurityGroupIds" , None )
245- subnets = self ._model .vpc_config .get ("Subnets" , None )
263+ self ._pop_out_non_configurable_repack_model_step_args ()
264+
265+ security_group_ids , subnets = self ._resolve_repack_model_step_vpc_configs ()
246266
247267 for i , model in enumerate (model_list ):
248268 runtime_repack_flg = (
@@ -252,8 +272,16 @@ def _append_repack_model_step(self):
252272 name_base = model .name or i
253273 repack_model_step = _RepackModelStep (
254274 name = "{}-{}-{}" .format (self .name , _REPACK_MODEL_NAME_BASE , name_base ),
255- sagemaker_session = self ._model .sagemaker_session or model .sagemaker_session ,
256- role = self ._model .role or model .role ,
275+ sagemaker_session = (
276+ self ._repack_model_step_settings .pop ("sagemaker_session" , None )
277+ or self ._model .sagemaker_session
278+ or model .sagemaker_session
279+ ),
280+ role = (
281+ self ._repack_model_step_settings .pop ("role" , None )
282+ or self ._model .role
283+ or model .role
284+ ),
257285 model_data = model .model_data ,
258286 entry_point = model .entry_point ,
259287 source_dir = model .source_dir ,
@@ -266,8 +294,15 @@ def _append_repack_model_step(self):
266294 ),
267295 depends_on = self .depends_on ,
268296 retry_policies = self ._repack_model_retry_policies ,
269- output_path = self ._runtime_repack_output_prefix ,
270- output_kms_key = model .model_kms_key ,
297+ output_path = (
298+ self ._repack_model_step_settings .pop ("output_path" , None )
299+ or self ._runtime_repack_output_prefix
300+ ),
301+ output_kms_key = (
302+ self ._repack_model_step_settings .pop ("output_kms_key" , None )
303+ or model .model_kms_key
304+ ),
305+ ** self ._repack_model_step_settings
271306 )
272307 self .steps .append (repack_model_step )
273308
@@ -282,3 +317,32 @@ def _append_repack_model_step(self):
282317 "InferenceSpecification"
283318 ]["Containers" ][i ]
284319 container ["ModelDataUrl" ] = repacked_model_data
320+
def _pop_out_non_configurable_repack_model_step_args(self):
    """Strip the non-configurable keys from ``_repack_model_step_settings``.

    Every key listed in ``_IGNORED_REPACK_PARAM_LIST`` is removed from the settings
    dict in place. A warning is logged for each removed key whose value was truthy;
    keys that are absent or hold a falsy value are dropped silently.
    """
    settings = self._repack_model_step_settings
    if settings:
        for param_name in _IGNORED_REPACK_PARAM_LIST:
            dropped_value = settings.pop(param_name, None)
            if not dropped_value:
                # Nothing meaningful was supplied for this key, so stay quiet.
                continue
            logger.warning(
                "The repack model step parameter - %s is not configurable. Ignoring it.",
                param_name,
            )
331+
332+ def _resolve_repack_model_step_vpc_configs (self ):
333+ """Resolve vpc configs for repack model step"""
334+ # Note: the EstimatorBase constructor ensures that:
335+ # "When setting up custom VPC, both subnets and security_group_ids must be set"
336+ if self ._repack_model_step_settings .get (
337+ "security_group_ids" , None
338+ ) or self ._repack_model_step_settings .get ("subnets" , None ):
339+ security_group_ids = self ._repack_model_step_settings .pop ("security_group_ids" , None )
340+ subnets = self ._repack_model_step_settings .pop ("subnets" , None )
341+ return security_group_ids , subnets
342+
343+ if self ._model .vpc_config :
344+ security_group_ids = self ._model .vpc_config .get ("SecurityGroupIds" , None )
345+ subnets = self ._model .vpc_config .get ("Subnets" , None )
346+ return security_group_ids , subnets
347+
348+ return None , None
0 commit comments