@@ -310,13 +310,34 @@ def _is_valid_model_id_hook():
310310
311311 super (JumpStartModel , self ).__init__ (** model_init_kwargs .to_kwargs_dict ())
312312
313- def _create_sagemaker_model (self , * args , ** kwargs ): # pylint: disable=unused-argument
313+ def _create_sagemaker_model (
314+ self ,
315+ instance_type = None ,
316+ accelerator_type = None ,
317+ tags = None ,
318+ serverless_inference_config = None ,
319+ ** kwargs ,
320+ ):
314321 """Create a SageMaker Model Entity
315322
316323 Args:
317- args: Positional arguments coming from the caller. This class does not require
318- any so they are ignored.
319-
324+ instance_type (str): Optional. The EC2 instance type that this Model will be
325+ used for, this is only used to determine if the image needs GPU
326+ support or not. (Default: None).
327+ accelerator_type (str): Optional. Type of Elastic Inference accelerator to
328+ attach to an endpoint for model loading and inference, for
329+ example, 'ml.eia1.medium'. If not specified, no Elastic
330+ Inference accelerator will be attached to the endpoint. (Default: None).
331+ tags (List[dict[str, str]]): Optional. The list of tags to add to
332+ the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
333+ 'tagvalue'}] For more information about tags, see
334+ https://boto3.amazonaws.com/v1/documentation
335+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
336+ (Default: None).
337+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
338+ Optional. Specifies configuration related to serverless endpoint. Instance type is
339+ not provided in serverless inference. So this is used to find image URIs.
340+ (Default: None).
320341 kwargs: Keyword arguments coming from the caller. This class does not require
321342 any so they are ignored.
322343 """
@@ -347,10 +368,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
347368 container_def ,
348369 vpc_config = self .vpc_config ,
349370 enable_network_isolation = self .enable_network_isolation (),
350- tags = kwargs . get ( " tags" ) ,
371+ tags = tags ,
351372 )
352373 else :
353- super (JumpStartModel , self )._create_sagemaker_model (* args , ** kwargs )
374+ super (JumpStartModel , self )._create_sagemaker_model (
375+ instance_type = instance_type ,
376+ accelerator_type = accelerator_type ,
377+ tags = tags ,
378+ serverless_inference_config = serverless_inference_config ,
379+ ** kwargs ,
380+ )
354381
355382 def deploy (
356383 self ,
0 commit comments