File tree Expand file tree Collapse file tree 5 files changed +18
-0
lines changed
tests/unit/sagemaker/workflow Expand file tree Collapse file tree 5 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -195,6 +195,7 @@ def _get_model_package_args(
195195 marketplace_cert = False ,
196196 approval_status = None ,
197197 description = None ,
198+ tags = None ,
198199 ):
199200 """Get arguments for session.create_model_package method.
200201
@@ -250,6 +251,8 @@ def _get_model_package_args(
250251 model_package_args ["approval_status" ] = approval_status
251252 if description is not None :
252253 model_package_args ["description" ] = description
254+ if tags is not None :
255+ model_package_args ["tags" ] = tags
253256 return model_package_args
254257
255258 def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
Original file line number Diff line number Diff line change @@ -2724,6 +2724,7 @@ def _get_create_model_package_request(
27242724 marketplace_cert = False ,
27252725 approval_status = "PendingManualApproval" ,
27262726 description = None ,
2727+ tags = None ,
27272728 ):
27282729 """Get request dictionary for CreateModelPackage API.
27292730
@@ -2761,6 +2762,8 @@ def _get_create_model_package_request(
27612762 request_dict ["ModelPackageGroupName" ] = model_package_group_name
27622763 if description is not None :
27632764 request_dict ["ModelPackageDescription" ] = description
2765+ if tags is not None :
2766+ request_dict ["Tags" ] = tags
27642767 if model_metrics :
27652768 request_dict ["ModelMetrics" ] = model_metrics
27662769 if metadata_properties :
Original file line number Diff line number Diff line change @@ -225,6 +225,7 @@ def __init__(
225225 compile_model_family = None ,
226226 description = None ,
227227 depends_on : List [str ] = None ,
228+ tags = None ,
228229 ** kwargs ,
229230 ):
230231 """Constructor of a register model step.
@@ -264,6 +265,7 @@ def __init__(
264265 self .inference_instances = inference_instances
265266 self .transform_instances = transform_instances
266267 self .model_package_group_name = model_package_group_name
268+ self .tags = tags
267269 self .model_metrics = model_metrics
268270 self .metadata_properties = metadata_properties
269271 self .approval_status = approval_status
@@ -324,10 +326,12 @@ def arguments(self) -> RequestType:
324326 metadata_properties = self .metadata_properties ,
325327 approval_status = self .approval_status ,
326328 description = self .description ,
329+ tags = self .tags ,
327330 )
328331 request_dict = model .sagemaker_session ._get_create_model_package_request (
329332 ** model_package_args
330333 )
334+
331335 # these are not available in the workflow service and will cause rejection
332336 if "CertifyForMarketplace" in request_dict :
333337 request_dict .pop ("CertifyForMarketplace" )
Original file line number Diff line number Diff line change @@ -67,6 +67,7 @@ def __init__(
6767 image_uri = None ,
6868 compile_model_family = None ,
6969 description = None ,
70+ tags = None ,
7071 ** kwargs ,
7172 ):
7273 """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,10 @@ def __init__(
9495 compile_model_family (str): The instance family for the compiled model. If
9596 specified, a compiled model is used (default: None).
9697 description (str): Model Package description (default: None).
98+ tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note
99+ that tags will only be applied to newly created model package groups; if the
100+ name of an existing group is passed to "model_package_group_name",
101+ tags will not be applied.
97102 **kwargs: additional arguments to `create_model`.
98103 """
99104 steps : List [Step ] = []
@@ -134,6 +139,7 @@ def __init__(
134139 image_uri = image_uri ,
135140 compile_model_family = compile_model_family ,
136141 description = description ,
142+ tags = tags ,
137143 ** kwargs ,
138144 )
139145 if not repack_model :
Original file line number Diff line number Diff line change @@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics):
182182 approval_status = "Approved" ,
183183 description = "description" ,
184184 depends_on = ["TestStep" ],
185+ tags = [{"Key" : "myKey" , "Value" : "myValue" }],
185186 )
186187 assert ordered (register_model .request_dicts ()) == ordered (
187188 [
@@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics):
210211 },
211212 "ModelPackageDescription" : "description" ,
212213 "ModelPackageGroupName" : "mpg" ,
214+ "Tags" : [{"Key" : "myKey" , "Value" : "myValue" }],
213215 },
214216 },
215217 ]
You can’t perform that action at this time.
0 commit comments