1313from __future__ import absolute_import
1414
1515import sagemaker
16- from sagemaker .fw_utils import create_image_uri , model_code_key_prefix
16+ from sagemaker .fw_utils import model_code_key_prefix
17+ from sagemaker .fw_registry import default_framework_uri
1718from sagemaker .model import FrameworkModel , MODEL_SERVER_WORKERS_PARAM_NAME
1819from sagemaker .predictor import RealTimePredictor , npy_serializer , numpy_deserializer
19- from sagemaker .sklearn .defaults import SKLEARN_VERSION
20+ from sagemaker .sklearn .defaults import SKLEARN_VERSION , SKLEARN_NAME
2021
2122
2223class SKLearnPredictor (RealTimePredictor ):
@@ -40,7 +41,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4041class SKLearnModel (FrameworkModel ):
4142 """An Scikit-learn SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
4243
43- __framework_name__ = 'scikit-learn'
44+ __framework_name__ = SKLEARN_NAME
4445
4546 def __init__ (self , model_data , role , entry_point , image = None , py_version = 'py3' , framework_version = SKLEARN_VERSION ,
4647 predictor_cls = SKLearnPredictor , model_server_workers = None , ** kwargs ):
@@ -77,16 +78,22 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
7778 Args:
7879 instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
7980 accelerator_type (str): The Elastic Inference accelerator type to deploy to the instance for loading and
80- making inferences to the model. For example, 'ml.eia1.medium'.
81+ making inferences to the model. For example, 'ml.eia1.medium'. Note: accelerator types are not
82+ supported by SKLearnModel.
8183
8284 Returns:
8385 dict[str, str]: A container definition object usable with the CreateModel API.
8486 """
87+ if accelerator_type :
88+ raise ValueError ("Accelerator types are not supported for Scikit-Learn." )
89+
8590 deploy_image = self .image
8691 if not deploy_image :
87- region_name = self .sagemaker_session .boto_session .region_name
88- deploy_image = create_image_uri (region_name , self .__framework_name__ , instance_type ,
89- self .framework_version , self .py_version , accelerator_type = accelerator_type )
92+ image_tag = "{}-{}-{}" .format (self .framework_version , "cpu" , self .py_version )
93+ deploy_image = default_framework_uri (
94+ self .__framework_name__ ,
95+ self .sagemaker_session .boto_region_name ,
96+ image_tag )
9097
9198 deploy_key_prefix = model_code_key_prefix (self .key_prefix , self .name , deploy_image )
9299 self ._upload_code (deploy_key_prefix )
0 commit comments