Skip to content

Commit 0ac6014

Browse files
Jonathan Makunga authored and
mufaddal-rohawala committed
JS Optimize api ref
1 parent 114a716 commit 0ac6014

File tree

2 files changed

16 additions & 14 deletions

2 files changed

16 additions & 14 deletions

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ def __init__(self):
109109
self.prepared_for_djl = None
110110
self.prepared_for_mms = None
111111
self.schema_builder = None
112-
self.nb_instance_type = None
112+
self.instance_type = None
113113
self.ram_usage_model_load = None
114114
self.model_hub = None
115115
self.model_metadata = None
@@ -138,7 +138,9 @@ def _is_jumpstart_model_id(self) -> bool:
138138

139139
def _create_pre_trained_js_model(self) -> Type[Model]:
140140
"""Placeholder docstring"""
141-
pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config)
141+
pysdk_model = JumpStartModel(
142+
self.model, vpc_config=self.vpc_config, instance_type=self.instance_type
143+
)
142144
pysdk_model.sagemaker_session = self.sagemaker_session
143145

144146
self._original_deploy = pysdk_model.deploy
@@ -234,8 +236,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
234236

235237
if "endpoint_logging" not in kwargs:
236238
kwargs["endpoint_logging"] = True
237-
if hasattr(self, "nb_instance_type"):
238-
kwargs.update({"instance_type": self.nb_instance_type})
239+
if self.instance_type:
240+
kwargs.update({"instance_type": self.instance_type})
239241

240242
if "mode" in kwargs:
241243
del kwargs["mode"]
@@ -268,7 +270,7 @@ def _build_for_djl_jumpstart(self):
268270
)
269271
self._prepare_for_mode()
270272
elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"):
271-
self.nb_instance_type = _get_nb_instance()
273+
self.instance_type = self.instance_type or _get_nb_instance()
272274
self.pysdk_model.model_data, env = self._prepare_for_mode()
273275

274276
self.pysdk_model.env.update(env)
@@ -647,7 +649,7 @@ def _optimize_for_jumpstart(
647649
self,
648650
output_path: Optional[str] = None,
649651
instance_type: Optional[str] = None,
650-
role: Optional[str] = None,
652+
role_arn: Optional[str] = None,
651653
tags: Optional[Tags] = None,
652654
job_name: Optional[str] = None,
653655
accept_eula: Optional[bool] = None,
@@ -665,7 +667,7 @@ def _optimize_for_jumpstart(
665667
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
666668
instance_type (Optional[str]): Target deployment instance type that
667669
the model is optimized for.
668-
role (Optional[str]): Execution role. Defaults to ``None``.
670+
role_arn (Optional[str]): Execution role. Defaults to ``None``.
669671
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
670672
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
671673
accept_eula (bool): For models that require a Model Access Config, specify True or
@@ -735,7 +737,7 @@ def _optimize_for_jumpstart(
735737
"DeploymentInstanceType": instance_type,
736738
"OptimizationConfigs": [optimization_config],
737739
"OutputConfig": output_config,
738-
"RoleArn": role,
740+
"RoleArn": role_arn,
739741
}
740742

741743
if optimization_env_vars:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -949,7 +949,7 @@ def optimize(self, *args, **kwargs) -> Model:
949949
instance_type (Optional[str]): Target deployment instance type that the
950950
model is optimized for.
951951
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
952-
role (Optional[str]): Execution role. Defaults to ``None``.
952+
role_arn (Optional[str]): Execution role. Defaults to ``None``.
953953
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
954954
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
955955
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -978,7 +978,7 @@ def _model_builder_optimize_wrapper(
978978
self,
979979
output_path: Optional[str] = None,
980980
instance_type: Optional[str] = None,
981-
role: Optional[str] = None,
981+
role_arn: Optional[str] = None,
982982
tags: Optional[Tags] = None,
983983
job_name: Optional[str] = None,
984984
accept_eula: Optional[bool] = None,
@@ -996,7 +996,7 @@ def _model_builder_optimize_wrapper(
996996
Args:
997997
output_path (str): Specifies where to store the compiled/quantized model.
998998
instance_type (str): Target deployment instance type that the model is optimized for.
999-
role (Optional[str]): Execution role. Defaults to ``None``.
999+
role_arn (Optional[str]): Execution role arn. Defaults to ``None``.
10001000
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
10011001
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
10021002
accept_eula (bool): For models that require a Model Access Config, specify True or
@@ -1030,8 +1030,8 @@ def _model_builder_optimize_wrapper(
10301030

10311031
if instance_type:
10321032
self.instance_type = instance_type
1033-
if role:
1034-
self.role = role
1033+
if role_arn:
1034+
self.role_arn = role_arn
10351035

10361036
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
10371037
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
@@ -1041,7 +1041,7 @@ def _model_builder_optimize_wrapper(
10411041
input_args = self._optimize_for_jumpstart(
10421042
output_path=output_path,
10431043
instance_type=instance_type,
1044-
role=role if role else self.role_arn,
1044+
role_arn=self.role_arn,
10451045
tags=tags,
10461046
job_name=job_name,
10471047
accept_eula=accept_eula,

0 commit comments

Comments
 (0)