Skip to content

Commit c41a7ca

Browse files
Jonathan Makungamufaddal-rohawala
authored andcommitted
Refactoring
1 parent 80fb96a commit c41a7ca

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
10981098
"gated_bucket",
10991099
"model_subscription_link",
11001100
"hosting_additional_data_sources",
1101+
"hosting_neuron_model_id",
1102+
"hosting_neuron_model_version",
11011103
]
11021104

11031105
def __init__(self, fields: Dict[str, Any]):
@@ -1208,6 +1210,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12081210
if json_obj.get("hosting_additional_data_sources")
12091211
else None
12101212
)
1213+
self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id")
1214+
self.hosting_neuron_model_version: Optional[str] = json_obj.get(
1215+
"hosting_neuron_model_version"
1216+
)
12111217

12121218
if self.training_supported:
12131219
self.training_ecr_specs: Optional[JumpStartECRSpecs] = (
@@ -2569,8 +2575,6 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
25692575
"model_data_download_timeout",
25702576
"container_startup_health_check_timeout",
25712577
"additional_data_sources",
2572-
"neuron_model_id",
2573-
"neuron_model_version",
25742578
]
25752579

25762580
def __init__(
@@ -2601,8 +2605,6 @@ def __init__(
26012605
"supported_inference_instance_types"
26022606
)
26032607
self.additional_data_sources = resolved_config.get("hosting_additional_data_sources")
2604-
self.neuron_model_id = resolved_config.get("hosting_neuron_model_id")
2605-
self.neuron_model_version = resolved_config.get("hosting_neuron_model_version")
26062608

26072609

26082610
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ def _is_jumpstart_model_id(self) -> bool:
139139

140140
def _create_pre_trained_js_model(self) -> Type[Model]:
141141
"""Placeholder docstring"""
142-
pysdk_model = JumpStartModel(
143-
self.model, vpc_config=self.vpc_config, instance_type=self.instance_type
144-
)
142+
pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config)
145143
pysdk_model.sagemaker_session = self.sagemaker_session
146144

147145
self._original_deploy = pysdk_model.deploy
@@ -696,12 +694,12 @@ def _optimize_for_jumpstart(
696694
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
697695
)
698696

697+
optimization_env_vars = env_vars
698+
pysdk_model_env_vars = env_vars
699+
699700
if compilation_config:
700-
neuro_model_id = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
701-
"NeuronModelId"
702-
)
703-
self.model = neuro_model_id
704-
self.pysdk_model = self._create_pre_trained_js_model()
701+
neuron_env = self._get_neuron_model_env_vars(instance_type)
702+
optimization_env_vars = _update_environment_variables(neuron_env, optimization_env_vars)
705703

706704
if speculative_decoding_config:
707705
self._set_additional_model_source(speculative_decoding_config)
@@ -714,11 +712,6 @@ def _optimize_for_jumpstart(
714712
)
715713

716714
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
717-
optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
718-
"Environment"
719-
)
720-
optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars)
721-
pysdk_model_env_vars = env_vars
722715

723716
optimization_config = {}
724717
if quantization_config:
@@ -874,3 +867,26 @@ def _find_compatible_deployment_config(
874867

875868
# fall back to the default jumpstart model deployment config for optimization job
876869
return self.pysdk_model.deployment_config
870+
871+
def _get_neuron_model_env_vars(
872+
self, instance_type: Optional[str] = None
873+
) -> Optional[Dict[str, Any]]:
874+
"""Gets Neuron model env vars.
875+
876+
Args:
877+
instance_type (Optional[str]): Instance type.
878+
879+
Returns:
880+
Optional[Dict[str, Any]]: Neuron Model environment variables.
881+
"""
882+
metadata_config = self.pysdk_model._metadata_configs.get(self.pysdk_model.config_name)
883+
resolve_config = metadata_config.resolved_config
884+
if instance_type not in resolve_config.get("supported_inference_instance_types", []):
885+
neuro_model_id = resolve_config.get("hosting_neuron_model_id")
886+
neuro_model_version = resolve_config.get("hosting_neuron_model_version")
887+
if neuro_model_id:
888+
job_model = JumpStartModel(
889+
neuro_model_id, model_version=neuro_model_version, vpc_config=self.vpc_config
890+
)
891+
return job_model.env
892+
return None

src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
237237
metadata={"help": "Define the s3 location where you want to upload the model package"},
238238
)
239239
instance_type: Optional[str] = field(
240-
default=None,
240+
default="ml.c5.xlarge",
241241
metadata={"help": "Define the instance_type of the endpoint"},
242242
)
243243
schema_builder: Optional[SchemaBuilder] = field(
@@ -1055,6 +1055,6 @@ def _model_builder_optimize_wrapper(
10551055
if input_args:
10561056
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
10571057
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1058-
self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status)
1058+
return _generate_optimized_model(self.pysdk_model, job_status)
10591059

10601060
return self.pysdk_model

0 commit comments

Comments
 (0)