1414
1515from __future__ import absolute_import
1616import logging
17+ import re
1718
1819from typing import Dict , List , Optional , Union
1920from sagemaker .async_inference .async_inference_config import AsyncInferenceConfig
3031)
3132from sagemaker .jumpstart .utils import is_valid_model_id
3233from sagemaker .utils import stringify_object
33- from sagemaker .model import Model
34+ from sagemaker .model import MODEL_PACKAGE_ARN_PATTERN , Model
3435from sagemaker .model_monitor .data_capture_config import DataCaptureConfig
3536from sagemaker .predictor import PredictorBase
3637from sagemaker .serverless .serverless_inference_config import ServerlessInferenceConfig
@@ -71,6 +72,7 @@ def __init__(
7172 container_log_level : Optional [Union [int , PipelineVariable ]] = None ,
7273 dependencies : Optional [List [str ]] = None ,
7374 git_config : Optional [Dict [str , str ]] = None ,
75+ model_package_arn : Optional [str ] = None ,
7476 ):
7577 """Initializes a ``JumpStartModel``.
7678
@@ -249,6 +251,9 @@ def __init__(
249251 >>> 'branch': 'test-branch-git-config',
250252 >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
251253
254+ model_package_arn (Optional[str]): An existing SageMaker Model Package arn,
255+ can be just the name if your account owns the Model Package.
256+ ``model_data`` is not required. (Default: None).
252257 Raises:
253258 ValueError: If the model ID is not recognized by JumpStart.
254259 """
@@ -291,6 +296,7 @@ def _is_valid_model_id_hook():
291296 container_log_level = container_log_level ,
292297 dependencies = dependencies ,
293298 git_config = git_config ,
299+ model_package_arn = model_package_arn ,
294300 )
295301
296302 self .orig_predictor_cls = predictor_cls
@@ -301,9 +307,49 @@ def _is_valid_model_id_hook():
301307 self .tolerate_vulnerable_model = model_init_kwargs .tolerate_vulnerable_model
302308 self .tolerate_deprecated_model = model_init_kwargs .tolerate_deprecated_model
303309 self .region = model_init_kwargs .region
310+ self .model_package_arn = model_init_kwargs .model_package_arn
304311
305312 super (JumpStartModel , self ).__init__ (** model_init_kwargs .to_kwargs_dict ())
306313
314+ def _create_sagemaker_model (self , * args , ** kwargs ): # pylint: disable=unused-argument
315+ """Create a SageMaker Model Entity
316+
317+ Args:
318+ args: Positional arguments coming from the caller. This class does not require
319+ any so they are ignored.
320+
321+ kwargs: Keyword arguments coming from the caller. This class does not require
322+ any so they are ignored.
323+ """
324+ if self .model_package_arn :
325+ # When a ModelPackageArn is provided we just create the Model
326+ match = re .match (MODEL_PACKAGE_ARN_PATTERN , self .model_package_arn )
327+ if match :
328+ model_package_name = match .group (3 )
329+ else :
330+ # model_package_arn can be just the name if your account owns the Model Package
331+ model_package_name = self .model_package_arn
332+ container_def = {"ModelPackageName" : self .model_package_arn }
333+
334+ if self .env != {}:
335+ container_def ["Environment" ] = self .env
336+
337+ if self .name is None :
338+ self ._base_name = model_package_name
339+
340+ self ._set_model_name_if_needed ()
341+
342+ self .sagemaker_session .create_model (
343+ self .name ,
344+ self .role ,
345+ container_def ,
346+ vpc_config = self .vpc_config ,
347+ enable_network_isolation = self .enable_network_isolation (),
348+ tags = kwargs .get ("tags" ),
349+ )
350+ else :
351+ super (JumpStartModel , self )._create_sagemaker_model (* args , ** kwargs )
352+
307353 def deploy (
308354 self ,
309355 initial_instance_count : Optional [int ] = None ,
0 commit comments