@@ -109,7 +109,7 @@ def __init__(self):
109
109
self .prepared_for_djl = None
110
110
self .prepared_for_mms = None
111
111
self .schema_builder = None
112
- self .nb_instance_type = None
112
+ self .instance_type = None
113
113
self .ram_usage_model_load = None
114
114
self .model_hub = None
115
115
self .model_metadata = None
@@ -138,7 +138,9 @@ def _is_jumpstart_model_id(self) -> bool:
138
138
139
139
def _create_pre_trained_js_model (self ) -> Type [Model ]:
140
140
"""Placeholder docstring"""
141
- pysdk_model = JumpStartModel (self .model , vpc_config = self .vpc_config )
141
+ pysdk_model = JumpStartModel (
142
+ self .model , vpc_config = self .vpc_config , instance_type = self .instance_type
143
+ )
142
144
pysdk_model .sagemaker_session = self .sagemaker_session
143
145
144
146
self ._original_deploy = pysdk_model .deploy
@@ -234,8 +236,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
234
236
235
237
if "endpoint_logging" not in kwargs :
236
238
kwargs ["endpoint_logging" ] = True
237
- if hasattr ( self , "nb_instance_type" ) :
238
- kwargs .update ({"instance_type" : self .nb_instance_type })
239
+ if self . instance_type :
240
+ kwargs .update ({"instance_type" : self .instance_type })
239
241
240
242
if "mode" in kwargs :
241
243
del kwargs ["mode" ]
@@ -268,7 +270,7 @@ def _build_for_djl_jumpstart(self):
268
270
)
269
271
self ._prepare_for_mode ()
270
272
elif self .mode == Mode .SAGEMAKER_ENDPOINT and hasattr (self , "prepared_for_djl" ):
271
- self .nb_instance_type = _get_nb_instance ()
273
+ self .instance_type = self . instance_type or _get_nb_instance ()
272
274
self .pysdk_model .model_data , env = self ._prepare_for_mode ()
273
275
274
276
self .pysdk_model .env .update (env )
@@ -647,7 +649,7 @@ def _optimize_for_jumpstart(
647
649
self ,
648
650
output_path : Optional [str ] = None ,
649
651
instance_type : Optional [str ] = None ,
650
- role : Optional [str ] = None ,
652
+ role_arn : Optional [str ] = None ,
651
653
tags : Optional [Tags ] = None ,
652
654
job_name : Optional [str ] = None ,
653
655
accept_eula : Optional [bool ] = None ,
@@ -665,7 +667,7 @@ def _optimize_for_jumpstart(
665
667
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
666
668
instance_type (Optional[str]): Target deployment instance type that
667
669
the model is optimized for.
668
- role (Optional[str]): Execution role. Defaults to ``None``.
670
+ role_arn (Optional[str]): Execution role. Defaults to ``None``.
669
671
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
670
672
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
671
673
accept_eula (bool): For models that require a Model Access Config, specify True or
@@ -735,7 +737,7 @@ def _optimize_for_jumpstart(
735
737
"DeploymentInstanceType" : instance_type ,
736
738
"OptimizationConfigs" : [optimization_config ],
737
739
"OutputConfig" : output_config ,
738
- "RoleArn" : role ,
740
+ "RoleArn" : role_arn ,
739
741
}
740
742
741
743
if optimization_env_vars :
0 commit comments