@@ -250,7 +250,7 @@ def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
250250 if urlparse (uri ).scheme == "s3" :
251251 bucket , _ = parse_s3_url (uri )
252252
253- return bucket in constants .JUMPSTART_BUCKET_NAME_SET
253+ return bucket in constants .JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET
254254
255255
256256def tag_key_in_array (tag_key : str , tag_array : List [Dict [str , str ]]) -> bool :
@@ -287,7 +287,10 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
287287
288288
289289def add_single_jumpstart_tag (
290- uri : str , tag_key : enums .JumpStartTag , curr_tags : Optional [List [Dict [str , str ]]]
290+ tag_value : str ,
291+ tag_key : enums .JumpStartTag ,
292+ curr_tags : Optional [List [Dict [str , str ]]],
293+ is_uri = False ,
291294) -> Optional [List ]:
292295 """Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
293296
@@ -296,17 +299,28 @@ def add_single_jumpstart_tag(
296299 tag_key (enums.JumpStartTag): Custom tag to apply to current tags if the URI
297300 corresponds to a JumpStart model.
298301 curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
302+ is_uri (boolean): Set to True to indicate a s3 uri is to be tagged. Set to False to indicate
303+ tags for JumpStart model id / version are being added. (Default: False).
299304 """
300- if is_jumpstart_model_uri (uri ):
305+ if not is_uri or is_jumpstart_model_uri (tag_value ):
301306 if curr_tags is None :
302307 curr_tags = []
303308 if not tag_key_in_array (tag_key , curr_tags ):
304- curr_tags .append (
305- {
306- "Key" : tag_key ,
307- "Value" : uri ,
308- }
309+ skip_adding_tag = (
310+ (
311+ tag_key_in_array (enums .JumpStartTag .MODEL_ID , curr_tags )
312+ or tag_key_in_array (enums .JumpStartTag .MODEL_VERSION , curr_tags )
313+ )
314+ if is_uri
315+ else False
309316 )
317+ if not skip_adding_tag :
318+ curr_tags .append (
319+ {
320+ "Key" : tag_key ,
321+ "Value" : tag_value ,
322+ }
323+ )
310324 return curr_tags
311325
312326
@@ -326,14 +340,37 @@ def get_jumpstart_base_name_if_jumpstart_model(
326340 return None
327341
328342
329- def add_jumpstart_tags (
343+ def add_jumpstart_model_id_version_tags (
344+ tags : Optional [List [Dict [str , str ]]],
345+ model_id : str ,
346+ model_version : str ,
347+ ) -> List [Dict [str , str ]]:
348+ """Add custom model ID and version tags to JumpStart related resources."""
349+ if model_id is None or model_version is None :
350+ return tags
351+ tags = add_single_jumpstart_tag (
352+ model_id ,
353+ enums .JumpStartTag .MODEL_ID ,
354+ tags ,
355+ is_uri = False ,
356+ )
357+ tags = add_single_jumpstart_tag (
358+ model_version ,
359+ enums .JumpStartTag .MODEL_VERSION ,
360+ tags ,
361+ is_uri = False ,
362+ )
363+ return tags
364+
365+
366+ def add_jumpstart_uri_tags (
330367 tags : Optional [List [Dict [str , str ]]] = None ,
331368 inference_model_uri : Optional [Union [str , dict ]] = None ,
332369 inference_script_uri : Optional [str ] = None ,
333370 training_model_uri : Optional [str ] = None ,
334371 training_script_uri : Optional [str ] = None ,
335372) -> Optional [List [Dict [str , str ]]]:
336- """Add custom tags to JumpStart models, return the updated tags.
373+ """Add custom uri tags to JumpStart models, return the updated tags.
337374
338375 No-op if this is not a JumpStart model related resource.
339376
@@ -362,31 +399,43 @@ def add_jumpstart_tags(
362399 logging .warning (warn_msg , "inference_model_uri" )
363400 else :
364401 tags = add_single_jumpstart_tag (
365- inference_model_uri , enums .JumpStartTag .INFERENCE_MODEL_URI , tags
402+ inference_model_uri ,
403+ enums .JumpStartTag .INFERENCE_MODEL_URI ,
404+ tags ,
405+ is_uri = True ,
366406 )
367407
368408 if inference_script_uri :
369409 if is_pipeline_variable (inference_script_uri ):
370410 logging .warning (warn_msg , "inference_script_uri" )
371411 else :
372412 tags = add_single_jumpstart_tag (
373- inference_script_uri , enums .JumpStartTag .INFERENCE_SCRIPT_URI , tags
413+ inference_script_uri ,
414+ enums .JumpStartTag .INFERENCE_SCRIPT_URI ,
415+ tags ,
416+ is_uri = True ,
374417 )
375418
376419 if training_model_uri :
377420 if is_pipeline_variable (training_model_uri ):
378421 logging .warning (warn_msg , "training_model_uri" )
379422 else :
380423 tags = add_single_jumpstart_tag (
381- training_model_uri , enums .JumpStartTag .TRAINING_MODEL_URI , tags
424+ training_model_uri ,
425+ enums .JumpStartTag .TRAINING_MODEL_URI ,
426+ tags ,
427+ is_uri = True ,
382428 )
383429
384430 if training_script_uri :
385431 if is_pipeline_variable (training_script_uri ):
386432 logging .warning (warn_msg , "training_script_uri" )
387433 else :
388434 tags = add_single_jumpstart_tag (
389- training_script_uri , enums .JumpStartTag .TRAINING_SCRIPT_URI , tags
435+ training_script_uri ,
436+ enums .JumpStartTag .TRAINING_SCRIPT_URI ,
437+ tags ,
438+ is_uri = True ,
390439 )
391440
392441 return tags
0 commit comments