Skip to content

Commit 80fb96a

Browse files
Jonathan Makungamu and faddal-rohawala
authored and committed
Refactoring
1 parent 0ac6014 commit 80fb96a

File tree

5 files changed

+26
-77
lines changed

5 files changed

+26
-77
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2569,6 +2569,8 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
25692569
"model_data_download_timeout",
25702570
"container_startup_health_check_timeout",
25712571
"additional_data_sources",
2572+
"neuron_model_id",
2573+
"neuron_model_version",
25722574
]
25732575

25742576
def __init__(
@@ -2599,6 +2601,8 @@ def __init__(
25992601
"supported_inference_instance_types"
26002602
)
26012603
self.additional_data_sources = resolved_config.get("hosting_additional_data_sources")
2604+
self.neuron_model_id = resolved_config.get("hosting_neuron_model_id")
2605+
self.neuron_model_version = resolved_config.get("hosting_neuron_model_version")
26022606

26032607

26042608
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self):
110110
self.prepared_for_mms = None
111111
self.schema_builder = None
112112
self.instance_type = None
113+
self.nb_instance_type = None
113114
self.ram_usage_model_load = None
114115
self.model_hub = None
115116
self.model_metadata = None
@@ -236,8 +237,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
236237

237238
if "endpoint_logging" not in kwargs:
238239
kwargs["endpoint_logging"] = True
239-
if self.instance_type:
240-
kwargs.update({"instance_type": self.instance_type})
240+
if hasattr(self, "nb_instance_type"):
241+
kwargs.update({"instance_type": self.nb_instance_type})
241242

242243
if "mode" in kwargs:
243244
del kwargs["mode"]
@@ -270,7 +271,7 @@ def _build_for_djl_jumpstart(self):
270271
)
271272
self._prepare_for_mode()
272273
elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"):
273-
self.instance_type = self.instance_type or _get_nb_instance()
274+
self.nb_instance_type = self.instance_type or _get_nb_instance()
274275
self.pysdk_model.model_data, env = self._prepare_for_mode()
275276

276277
self.pysdk_model.env.update(env)
@@ -695,25 +696,29 @@ def _optimize_for_jumpstart(
695696
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
696697
)
697698

698-
optimization_env_vars = None
699-
pysdk_model_env_vars = None
700-
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
699+
if compilation_config:
700+
neuro_model_id = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
701+
"NeuronModelId"
702+
)
703+
self.model = neuro_model_id
704+
self.pysdk_model = self._create_pre_trained_js_model()
701705

702706
if speculative_decoding_config:
703707
self._set_additional_model_source(speculative_decoding_config)
704-
optimization_env_vars = self.pysdk_model.deployment_config.get(
705-
"DeploymentArgs", {}
706-
).get("Environment")
707708
else:
708709
deployment_config = self._find_compatible_deployment_config(None)
709710
if deployment_config:
710-
optimization_env_vars = deployment_config.get("DeploymentArgs").get("Environment")
711711
self.pysdk_model.set_deployment_config(
712712
config_name=deployment_config.get("DeploymentConfigName"),
713713
instance_type=deployment_config.get("InstanceType"),
714714
)
715715

716+
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
717+
optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
718+
"Environment"
719+
)
716720
optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars)
721+
pysdk_model_env_vars = env_vars
717722

718723
optimization_config = {}
719724
if quantization_config:
@@ -730,6 +735,10 @@ def _optimize_for_jumpstart(
730735
output_config = {"S3OutputLocation": output_path}
731736
if kms_key:
732737
output_config["KmsKeyId"] = kms_key
738+
if not instance_type:
739+
instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
740+
"InstanceType"
741+
)
733742

734743
create_optimization_job_args = {
735744
"OptimizationJobName": job_name,

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
6666
from sagemaker.serve.utils.optimize_utils import (
6767
_generate_optimized_model,
68-
_validate_optimization_inputs,
6968
)
7069
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
7170
from sagemaker.serve.utils.hardware_detector import (
@@ -238,7 +237,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
238237
metadata={"help": "Define the s3 location where you want to upload the model package"},
239238
)
240239
instance_type: Optional[str] = field(
241-
default="ml.c5.xlarge",
240+
default=None,
242241
metadata={"help": "Define the instance_type of the endpoint"},
243242
)
244243
schema_builder: Optional[SchemaBuilder] = field(
@@ -1022,9 +1021,8 @@ def _model_builder_optimize_wrapper(
10221021
Returns:
10231022
Model: A deployable ``Model`` object.
10241023
"""
1025-
_validate_optimization_inputs(
1026-
output_path, instance_type, quantization_config, compilation_config
1027-
)
1024+
if quantization_config and compilation_config:
1025+
raise ValueError("Quantization config and compilation config are mutually exclusive.")
10281026

10291027
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
10301028

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -160,42 +160,6 @@ def _extracts_and_validates_speculative_model_source(
160160
return s3_uri
161161

162162

163-
def _validate_optimization_inputs(
164-
output_path: Optional[str] = None,
165-
instance_type: Optional[str] = None,
166-
quantization_config: Optional[Dict] = None,
167-
compilation_config: Optional[Dict] = None,
168-
) -> None:
169-
"""Validates optimization inputs.
170-
171-
Args:
172-
output_path (Optional[str]): The output path.
173-
instance_type (Optional[str]): The instance type.
174-
quantization_config (Optional[Dict]): The quantization config.
175-
compilation_config (Optional[Dict]): The compilation config.
176-
177-
Raises:
178-
ValueError: If an optimization input is invalid.
179-
"""
180-
if quantization_config and compilation_config:
181-
raise ValueError("Quantization config and compilation config are mutually exclusive.")
182-
183-
instance_type_msg = "Please provide an instance type for %s optimization job."
184-
output_path_msg = "Please provide an output path for %s optimization job."
185-
186-
if quantization_config:
187-
if not instance_type:
188-
raise ValueError(instance_type_msg.format("quantization"))
189-
if not output_path:
190-
raise ValueError(output_path_msg.format("quantization"))
191-
192-
if compilation_config:
193-
if not instance_type:
194-
raise ValueError(instance_type_msg.format("compilation"))
195-
if not output_path:
196-
raise ValueError(output_path_msg.format("compilation"))
197-
198-
199163
def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str:
200164
"""Generates a channel name.
201165

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
_update_environment_variables,
2323
_is_image_compatible_with_optimization_job,
2424
_extract_speculative_draft_model_provider,
25-
_validate_optimization_inputs,
2625
_extracts_and_validates_speculative_model_source,
2726
_is_s3_uri,
2827
_generate_additional_model_data_sources,
@@ -168,31 +167,6 @@ def test_extract_speculative_draft_model_provider(
168167
)
169168

170169

171-
@pytest.mark.parametrize(
172-
"output_path, instance, quantization_config, compilation_config",
173-
[
174-
(
175-
None,
176-
None,
177-
{"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}},
178-
{"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}},
179-
),
180-
(None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, None),
181-
(None, None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}),
182-
("output_path", None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}),
183-
(None, "instance_type", None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}),
184-
],
185-
)
186-
def test_validate_optimization_inputs(
187-
output_path, instance, quantization_config, compilation_config
188-
):
189-
190-
with pytest.raises(ValueError):
191-
_validate_optimization_inputs(
192-
output_path, instance, quantization_config, compilation_config
193-
)
194-
195-
196170
def test_extract_speculative_draft_model_s3_uri():
197171
res = _extracts_and_validates_speculative_model_source({"ModelSource": "s3://"})
198172
assert res == "s3://"

0 commit comments

Comments (0)