Skip to content

Commit 1d1233d

Browse files
author
Jonathan Makunga
committed
refactoring
1 parent f68b71b commit 1d1233d

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@ def _optimize_for_jumpstart(
670670
output_path: Optional[str] = None,
671671
tags: Optional[Tags] = None,
672672
job_name: Optional[str] = None,
673+
instance_type: Optional[str] = None,
673674
accept_eula: Optional[bool] = None,
674675
quantization_config: Optional[Dict] = None,
675676
compilation_config: Optional[Dict] = None,
@@ -685,6 +686,7 @@ def _optimize_for_jumpstart(
685686
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
686687
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
687688
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
689+
instance_type (str): Target deployment instance type that the model is optimized for.
688690
accept_eula (bool): For models that require a Model Access Config, specify True or
689691
False to indicate whether model terms of use have been accepted.
690692
The `accept_eula` value must be explicitly defined as `True` in order to
@@ -711,12 +713,12 @@ def _optimize_for_jumpstart(
711713
)
712714

713715
is_compilation = (not quantization_config) and (
714-
(compilation_config is not None) or _is_inferentia_or_trainium(self.instance_type)
716+
(compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
715717
)
716718

717719
pysdk_model_env_vars = dict()
718720
if is_compilation:
719-
pysdk_model_env_vars = self._get_neuron_model_env_vars(self.instance_type)
721+
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
720722

721723
optimization_config, override_env = _extract_optimization_config_and_env(
722724
quantization_config, compilation_config
@@ -752,9 +754,7 @@ def _optimize_for_jumpstart(
752754
if self.pysdk_model.deployment_config
753755
else None
754756
)
755-
self.instance_type = (
756-
self.instance_type or deployment_config_instance_type or _get_nb_instance()
757-
)
757+
self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance()
758758

759759
create_optimization_job_args = {
760760
"OptimizationJobName": job_name,

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,7 @@ def _model_builder_optimize_wrapper(
12291229
output_path=output_path,
12301230
tags=tags,
12311231
job_name=job_name,
1232+
instance_type=instance_type,
12321233
accept_eula=accept_eula,
12331234
quantization_config=quantization_config,
12341235
compilation_config=compilation_config,

0 commit comments

Comments
 (0)