9898 to_string ,
9999 check_and_get_run_experiment_config ,
100100 resolve_value_from_config ,
101+ format_tags ,
102+ Tags ,
101103)
102104from sagemaker .workflow import is_pipeline_variable
103105from sagemaker .workflow .entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144146 output_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
145147 base_job_name : Optional [str ] = None ,
146148 sagemaker_session : Optional [Session ] = None ,
147- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
149+ tags : Optional [Tags ] = None ,
148150 subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
149151 security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
150152 model_uri : Optional [str ] = None ,
@@ -270,8 +272,8 @@ def __init__(
270272 manages interactions with Amazon SageMaker APIs and any other
271273 AWS services needed. If not specified, the estimator creates one
272274 using the default AWS configuration chain.
273- tags (list[dict[str, str] or list[dict[str, PipelineVariable] ]):
274- List of tags for labeling a training job. For more, see
275+ tags (Optional[Tags ]):
276+ Tags for labeling a training job. For more, see
275277 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
276278 subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
277279 specified training job will be created without VPC config.
@@ -604,6 +606,7 @@ def __init__(
604606 else :
605607 self .sagemaker_session = sagemaker_session or Session ()
606608
609+ tags = format_tags (tags )
607610 self .tags = (
608611 add_jumpstart_uri_tags (
609612 tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
@@ -1352,7 +1355,7 @@ def compile_model(
13521355 framework = None ,
13531356 framework_version = None ,
13541357 compile_max_run = 15 * 60 ,
1355- tags = None ,
1358+ tags : Optional [ Tags ] = None ,
13561359 target_platform_os = None ,
13571360 target_platform_arch = None ,
13581361 target_platform_accelerator = None ,
@@ -1378,7 +1381,7 @@ def compile_model(
13781381 compile_max_run (int): Timeout in seconds for compilation (default:
13791382 15 * 60). After this amount of time Amazon SageMaker Neo
13801383 terminates the compilation job regardless of its current status.
1381- tags (list[dict]): List of tags for labeling a compilation job. For
1384+ tags (list[dict]): Tags for labeling a compilation job. For
13821385 more, see
13831386 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
13841387 target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1420,7 +1423,7 @@ def compile_model(
14201423 input_shape ,
14211424 output_path ,
14221425 self .role ,
1423- tags ,
1426+ format_tags ( tags ) ,
14241427 self ._compilation_job_name (),
14251428 compile_max_run ,
14261429 framework = framework ,
@@ -1532,7 +1535,7 @@ def deploy(
15321535 model_name = None ,
15331536 kms_key = None ,
15341537 data_capture_config = None ,
1535- tags = None ,
1538+ tags : Optional [ Tags ] = None ,
15361539 serverless_inference_config = None ,
15371540 async_inference_config = None ,
15381541 volume_size = None ,
@@ -1601,8 +1604,10 @@ def deploy(
16011604 empty object passed through, will use pre-defined values in
16021605 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
16031606 instance based endpoint if it's None. (default: None)
1604- tags(List[dict[str, str]] ): Optional. The list of tags to attach to this specific
1607+ tags(Optional[Tags] ): Optional. Tags to attach to this specific
16051608 endpoint. Example:
1609+ >>> tags = {'tagname', 'tagvalue'}
1610+ Or
16061611 >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
16071612 For more information about tags, see
16081613 https://boto3.amazonaws.com/v1/documentation\
@@ -1664,7 +1669,7 @@ def deploy(
16641669 model .name = model_name
16651670
16661671 tags = update_inference_tags_with_jumpstart_training_tags (
1667- inference_tags = tags , training_tags = self .tags
1672+ inference_tags = format_tags ( tags ) , training_tags = self .tags
16681673 )
16691674
16701675 return model .deploy (
@@ -2017,7 +2022,7 @@ def transformer(
20172022 env = None ,
20182023 max_concurrent_transforms = None ,
20192024 max_payload = None ,
2020- tags = None ,
2025+ tags : Optional [ Tags ] = None ,
20212026 role = None ,
20222027 volume_kms_key = None ,
20232028 vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
@@ -2051,7 +2056,7 @@ def transformer(
20512056 to be made to each individual transform container at one time.
20522057 max_payload (int): Maximum size of the payload in a single HTTP
20532058 request to the container in MB.
2054- tags (list[dict ]): List of tags for labeling a transform job. If
2059+ tags (Optional[Tags ]): Tags for labeling a transform job. If
20552060 none specified, then the tags used for the training job are used
20562061 for the transform job.
20572062 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2078,7 +2083,7 @@ def transformer(
20782083 model. If not specified, the estimator generates a default job name
20792084 based on the training image name and current timestamp.
20802085 """
2081- tags = tags or self .tags
2086+ tags = format_tags ( tags ) or self .tags
20822087 model_name = self ._get_or_create_name (model_name )
20832088
20842089 if self .latest_training_job is None :
@@ -2717,7 +2722,7 @@ def __init__(
27172722 base_job_name : Optional [str ] = None ,
27182723 sagemaker_session : Optional [Session ] = None ,
27192724 hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
2720- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
2725+ tags : Optional [Tags ] = None ,
27212726 subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
27222727 security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
27232728 model_uri : Optional [str ] = None ,
@@ -2847,7 +2852,7 @@ def __init__(
28472852 hyperparameters. SageMaker rejects the training job request and returns an
28482853 validation error for detected credentials, if such user input is found.
28492854
2850- tags (list[dict[str, str] or list[dict[str, PipelineVariable]] ): List of tags for
2855+ tags (Optional[Tags] ): Tags for
28512856 labeling a training job. For more, see
28522857 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
28532858 subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3130,7 +3135,7 @@ def __init__(
31303135 output_kms_key ,
31313136 base_job_name ,
31323137 sagemaker_session ,
3133- tags ,
3138+ format_tags ( tags ) ,
31343139 subnets ,
31353140 security_group_ids ,
31363141 model_uri = model_uri ,
@@ -3762,7 +3767,7 @@ def transformer(
37623767 env = None ,
37633768 max_concurrent_transforms = None ,
37643769 max_payload = None ,
3765- tags = None ,
3770+ tags : Optional [ Tags ] = None ,
37663771 role = None ,
37673772 model_server_workers = None ,
37683773 volume_kms_key = None ,
@@ -3798,7 +3803,7 @@ def transformer(
37983803 to be made to each individual transform container at one time.
37993804 max_payload (int): Maximum size of the payload in a single HTTP
38003805 request to the container in MB.
3801- tags (list[dict ]): List of tags for labeling a transform job. If
3806+ tags (Optional[Tags ]): Tags for labeling a transform job. If
38023807 none specified, then the tags used for the training job are used
38033808 for the transform job.
38043809 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3837,7 +3842,7 @@ def transformer(
38373842 SageMaker Batch Transform job.
38383843 """
38393844 role = role or self .role
3840- tags = tags or self .tags
3845+ tags = format_tags ( tags ) or self .tags
38413846 model_name = self ._get_or_create_name (model_name )
38423847
38433848 if self .latest_training_job is not None :
0 commit comments